!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright 2000-2025 CP2K developers group <https://cp2k.org>                                   !
!                                                                                                  !
!   SPDX-License-Identifier: GPL-2.0-or-later                                                      !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief generates the task lists used by the collocate and integrate routines
!> \par History
!>      01.2008 [Joost VandeVondele] refactored out of qs_collocate / qs_integrate
!> \author Joost VandeVondele
! **************************************************************************************************
MODULE task_list_methods
   USE offload_api, ONLY: offload_create_buffer, offload_buffer_type
   USE grid_api, ONLY: grid_create_basis_set, grid_create_task_list
   USE ao_util, ONLY: exp_radius_very_extended
   USE basis_set_types, ONLY: get_gto_basis_set, &
                              gto_basis_set_p_type, &
                              gto_basis_set_type
   USE cell_types, ONLY: cell_type, &
                         pbc
   USE cp_control_types, ONLY: dft_control_type
   USE cube_utils, ONLY: compute_cube_center, &
                         cube_info_type, &
                         return_cube, &
                         return_cube_nonortho
   USE cp_dbcsr_api, ONLY: dbcsr_convert_sizes_to_offsets, &
                           dbcsr_finalize, &
                           dbcsr_get_block_p, &
                           dbcsr_get_info, &
                           dbcsr_p_type, &
                           dbcsr_put_block, &
                           dbcsr_type, &
                           dbcsr_work_create
   USE gaussian_gridlevels, ONLY: gaussian_gridlevel, &
                                  gridlevel_info_type
   USE kinds, ONLY: default_string_length, &
                    dp, &
                    int_8
   USE kpoint_types, ONLY: get_kpoint_info, &
                           kpoint_type
   USE memory_utilities, ONLY: reallocate
   USE message_passing, ONLY: &
      mp_comm_type
   USE particle_types, ONLY: particle_type
   USE particle_methods, ONLY: get_particle_set
   USE pw_env_types, ONLY: pw_env_get, &
                           pw_env_type
   USE qs_kind_types, ONLY: get_qs_kind, &
                            qs_kind_type
   USE qs_ks_types, ONLY: get_ks_env, &
                          qs_ks_env_type
   USE qs_neighbor_list_types, ONLY: neighbor_list_set_p_type
   USE realspace_grid_types, ONLY: realspace_grid_desc_p_type, &
                                   realspace_grid_desc_type, &
                                   rs_grid_create, &
                                   rs_grid_locate_rank, &
                                   rs_grid_release, &
                                   rs_grid_reorder_ranks, realspace_grid_type
   USE task_list_types, ONLY: deserialize_task, &
                              reallocate_tasks, &
                              serialize_task, &
                              task_list_type, &
                              atom_pair_type, &
                              task_size_in_int8, &
                              task_type
   USE util, ONLY: sort

!$ USE OMP_LIB, ONLY: omp_destroy_lock, omp_get_num_threads, omp_init_lock, &
!$                    omp_lock_kind, omp_set_lock, omp_unset_lock, omp_get_max_threads
#include "./base/base_uses.f90"

   #:include './common/array_sort.fypp'

   IMPLICIT NONE

   LOGICAL, PRIVATE, PARAMETER :: debug_this_module = .FALSE.

   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'task_list_methods'

   PUBLIC :: generate_qs_task_list, &
             task_list_inner_loop
   PUBLIC :: distribute_tasks, &
             rs_distribute_matrix, &
             rs_scatter_matrices, &
             rs_gather_matrices, &
             rs_copy_to_buffer, &
             rs_copy_to_matrices

CONTAINS

! **************************************************************************************************
!> \brief Builds the list of collocate/integrate tasks for all atom pairs of the neighbor list,
!>        distributes it over the realspace grids and sets up the per-grid-level atom-pair
!>        index arrays (npairs, taskstart, taskstop) as well as the grid backend task list.
!> \param ks_env environment providing qs_kind_set, cell, particle_set, dft_control, sab_orb,
!>        pw_env and, when dft_control%nimages > 1, the k-point cell_to_index map
!> \param task_list task list that is (re)filled; previously allocated atom-pair send/recv lists
!>        and taskstart/taskstop/npairs arrays are released and rebuilt
!> \param basis_type basis set type requested from get_qs_kind for every kind; kinds that have
!>        no basis set of this type are skipped when generating tasks
!> \param reorder_rs_grid_ranks Flag that indicates if this routine should
!>        or should not overwrite the rs descriptor (see note below); it is forwarded to
!>        distribute_tasks and, when .TRUE., the distributed rs_grids are released and recreated
!> \param skip_load_balance_distributed forwarded unchanged to distribute_tasks
!> \param pw_env_external if present, used instead of the pw_env stored in ks_env
!> \param sab_orb_external if present, used instead of the sab_orb neighbor list stored in ks_env
!> \par History
!>      01.2008 factored out of calculate_rho_elec [Joost VandeVondele]
!>      04.2010 divides tasks into grid levels and atom pairs for integrate/collocate [Iain Bethune]
!>              (c) The Numerical Algorithms Group (NAG) Ltd, 2010 on behalf of the HECToR project
!>      06.2015 adjusted to be used with multiple images (k-points) [JGH]
!> \note  If this routine is called several times with different task lists,
!>        the default behaviour is to re-optimize the grid ranks and overwrite
!>        the rs descriptor and grids. reorder_rs_grid_ranks = .FALSE. prevents the code
!>        of performing a new optimization by leaving the rank order in
!>        its current state.
! **************************************************************************************************
   SUBROUTINE generate_qs_task_list(ks_env, task_list, basis_type, &
                                    reorder_rs_grid_ranks, skip_load_balance_distributed, &
                                    pw_env_external, sab_orb_external)

      TYPE(qs_ks_env_type), POINTER                      :: ks_env
      TYPE(task_list_type), POINTER                      :: task_list
      CHARACTER(LEN=*), INTENT(IN)                       :: basis_type
      LOGICAL, INTENT(IN)                                :: reorder_rs_grid_ranks, &
                                                            skip_load_balance_distributed
      TYPE(pw_env_type), OPTIONAL, POINTER               :: pw_env_external
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         OPTIONAL, POINTER                               :: sab_orb_external

      CHARACTER(LEN=*), PARAMETER :: routineN = 'generate_qs_task_list'
      INTEGER, PARAMETER                                 :: max_tasks = 2000

      INTEGER :: cindex, curr_tasks, handle, i, iatom, iatom_old, igrid_level, igrid_level_old, &
                 ikind, ilevel, img, img_old, ipair, ipgf, iset, itask, jatom, jatom_old, jkind, jpgf, &
                 jset, natoms, nimages, nkind, nseta, nsetb, slot
      INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: blocks
      INTEGER, DIMENSION(3)                              :: cellind
      INTEGER, DIMENSION(:), POINTER                     :: la_max, la_min, lb_max, lb_min, npgfa, &
                                                            npgfb, nsgf
      INTEGER, DIMENSION(:, :, :), POINTER               :: cell_to_index
      LOGICAL                                            :: dokp
      REAL(KIND=dp)                                      :: kind_radius_b
      REAL(KIND=dp), DIMENSION(3)                        :: ra, rab
      REAL(KIND=dp), DIMENSION(:), POINTER               :: set_radius_a, set_radius_b
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: rpgfa, rpgfb, zeta, zetb
      TYPE(cell_type), POINTER                           :: cell
      TYPE(cube_info_type), DIMENSION(:), POINTER        :: cube_info
      TYPE(dft_control_type), POINTER                    :: dft_control
      TYPE(gridlevel_info_type), POINTER                 :: gridlevel_info
      TYPE(gto_basis_set_p_type), DIMENSION(:), POINTER  :: basis_set_list
      TYPE(gto_basis_set_type), POINTER                  :: basis_set_a, basis_set_b
      TYPE(kpoint_type), POINTER                         :: kpoints
      TYPE(neighbor_list_set_p_type), DIMENSION(:), &
         POINTER                                         :: sab_orb
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(pw_env_type), POINTER                         :: pw_env
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      TYPE(qs_kind_type), POINTER                        :: qs_kind
      TYPE(realspace_grid_desc_p_type), DIMENSION(:), &
         POINTER                                         :: rs_descs
      TYPE(realspace_grid_type), DIMENSION(:), POINTER :: rs_grids
      TYPE(task_type), DIMENSION(:), POINTER             :: tasks

      CALL timeset(routineN, handle)

      CALL get_ks_env(ks_env, &
                      qs_kind_set=qs_kind_set, &
                      cell=cell, &
                      particle_set=particle_set, &
                      dft_control=dft_control)

      CALL get_ks_env(ks_env, sab_orb=sab_orb)
      IF (PRESENT(sab_orb_external)) sab_orb => sab_orb_external

      CALL get_ks_env(ks_env, pw_env=pw_env)
      IF (PRESENT(pw_env_external)) pw_env => pw_env_external
      CALL pw_env_get(pw_env, rs_descs=rs_descs, rs_grids=rs_grids)

      ! *** assign from pw_env
      gridlevel_info => pw_env%gridlevel_info
      cube_info => pw_env%cube_info

      nkind = SIZE(qs_kind_set)
      natoms = SIZE(particle_set)

      ! kpoint related
      nimages = dft_control%nimages
      IF (nimages > 1) THEN
         dokp = .TRUE.
         NULLIFY (kpoints)
         CALL get_ks_env(ks_env=ks_env, kpoints=kpoints)
         CALL get_kpoint_info(kpoint=kpoints, cell_to_index=cell_to_index)
      ELSE
         dokp = .FALSE.
         NULLIFY (cell_to_index)
      END IF

      ! free the atom_pair lists if allocated
      IF (ASSOCIATED(task_list%atom_pair_send)) DEALLOCATE (task_list%atom_pair_send)
      IF (ASSOCIATED(task_list%atom_pair_recv)) DEALLOCATE (task_list%atom_pair_recv)

      ! construct a list of all tasks
      IF (.NOT. ASSOCIATED(task_list%tasks)) THEN
         CALL reallocate_tasks(task_list%tasks, max_tasks)
      END IF
      task_list%ntasks = 0
      curr_tasks = SIZE(task_list%tasks)

      ! Per-kind basis sets of the requested type; kinds without one keep a null pointer
      ! and are skipped in the neighbor-list loop below.
      ALLOCATE (basis_set_list(nkind))
      DO ikind = 1, nkind
         qs_kind => qs_kind_set(ikind)
         CALL get_qs_kind(qs_kind=qs_kind, basis_set=basis_set_a, &
                          basis_type=basis_type)
         IF (ASSOCIATED(basis_set_a)) THEN
            basis_set_list(ikind)%gto_basis_set => basis_set_a
         ELSE
            NULLIFY (basis_set_list(ikind)%gto_basis_set)
         END IF
      END DO
!!$OMP PARALLEL DEFAULT(NONE) &
!!$OMP SHARED (sab_orb, dokp, basis_set_list, task_list, rs_descs, dft_control, cube_info, gridlevel_info,  &
!!$OMP         curr_tasks, natoms, nimages, particle_set, cell_to_index, cell) &
!!$OMP PRIVATE (ikind, jkind, iatom, jatom, rab, cellind, basis_set_a, basis_set_b, ra, &
!!$OMP          la_max, la_min, npgfa, nseta, rpgfa, set_radius_a, zeta, &
!!$OMP          lb_max, lb_min, npgfb, nsetb, rpgfb, set_radius_b, kind_radius_b, zetb, &
!!$OMP          cindex, slot)
      ! Loop over neighbor list
!!$OMP DO SCHEDULE(GUIDED)
      DO slot = 1, sab_orb(1)%nl_size
         ikind = sab_orb(1)%nlist_task(slot)%ikind
         jkind = sab_orb(1)%nlist_task(slot)%jkind
         iatom = sab_orb(1)%nlist_task(slot)%iatom
         jatom = sab_orb(1)%nlist_task(slot)%jatom
         rab(1:3) = sab_orb(1)%nlist_task(slot)%r(1:3)
         cellind(1:3) = sab_orb(1)%nlist_task(slot)%cell(1:3)

         basis_set_a => basis_set_list(ikind)%gto_basis_set
         IF (.NOT. ASSOCIATED(basis_set_a)) CYCLE
         basis_set_b => basis_set_list(jkind)%gto_basis_set
         IF (.NOT. ASSOCIATED(basis_set_b)) CYCLE
         ra(:) = pbc(particle_set(iatom)%r, cell)
         ! basis ikind
         la_max => basis_set_a%lmax
         la_min => basis_set_a%lmin
         npgfa => basis_set_a%npgf
         nseta = basis_set_a%nset
         rpgfa => basis_set_a%pgf_radius
         set_radius_a => basis_set_a%set_radius
         zeta => basis_set_a%zet
         ! basis jkind
         lb_max => basis_set_b%lmax
         lb_min => basis_set_b%lmin
         npgfb => basis_set_b%npgf
         nsetb = basis_set_b%nset
         rpgfb => basis_set_b%pgf_radius
         set_radius_b => basis_set_b%set_radius
         kind_radius_b = basis_set_b%kind_radius
         zetb => basis_set_b%zet

         ! image index: the k-point cell index when k-points are used, 1 otherwise
         IF (dokp) THEN
            cindex = cell_to_index(cellind(1), cellind(2), cellind(3))
         ELSE
            cindex = 1
         END IF

         CALL task_list_inner_loop(task_list%tasks, task_list%ntasks, curr_tasks, &
                                   rs_descs, dft_control, cube_info, gridlevel_info, cindex, &
                                   iatom, jatom, rpgfa, rpgfb, zeta, zetb, kind_radius_b, &
                                   set_radius_a, set_radius_b, ra, rab, &
                                   la_max, la_min, lb_max, lb_min, npgfa, npgfb, nseta, nsetb)

      END DO
!!$OMP END PARALLEL

      ! redistribute the task list so that all tasks map on the local rs grids
      CALL distribute_tasks( &
         rs_descs=rs_descs, ntasks=task_list%ntasks, natoms=natoms, &
         tasks=task_list%tasks, atom_pair_send=task_list%atom_pair_send, &
         atom_pair_recv=task_list%atom_pair_recv, symmetric=.TRUE., &
         reorder_rs_grid_ranks=reorder_rs_grid_ranks, &
         skip_load_balance_distributed=skip_load_balance_distributed)

      ! compute offsets for rs_scatter_matrix / rs_copy_matrix
      ALLOCATE (nsgf(natoms))
      CALL get_particle_set(particle_set, qs_kind_set, basis=basis_set_list, nsgf=nsgf)
      IF (ASSOCIATED(task_list%atom_pair_send)) THEN
         ! only needed when there is a distributed grid
         CALL rs_calc_offsets(pairs=task_list%atom_pair_send, &
                              nsgf=nsgf, &
                              group_size=rs_descs(1)%rs_desc%group_size, &
                              pair_offsets=task_list%pair_offsets_send, &
                              rank_offsets=task_list%rank_offsets_send, &
                              rank_sizes=task_list%rank_sizes_send, &
                              buffer_size=task_list%buffer_size_send)
      END IF
      CALL rs_calc_offsets(pairs=task_list%atom_pair_recv, &
                           nsgf=nsgf, &
                           group_size=rs_descs(1)%rs_desc%group_size, &
                           pair_offsets=task_list%pair_offsets_recv, &
                           rank_offsets=task_list%rank_offsets_recv, &
                           rank_sizes=task_list%rank_sizes_recv, &
                           buffer_size=task_list%buffer_size_recv)
      DEALLOCATE (basis_set_list, nsgf)

      ! If the rank order has changed, reallocate any of the distributed rs_grids
      IF (reorder_rs_grid_ranks) THEN
         DO i = 1, gridlevel_info%ngrid_levels
            IF (rs_descs(i)%rs_desc%distributed) THEN
               CALL rs_grid_release(rs_grids(i))
               CALL rs_grid_create(rs_grids(i), rs_descs(i)%rs_desc)
            END IF
         END DO
      END IF

      CALL create_grid_task_list(task_list=task_list, &
                                 qs_kind_set=qs_kind_set, &
                                 particle_set=particle_set, &
                                 cell=cell, &
                                 basis_type=basis_type, &
                                 rs_grids=rs_grids)

      ! Now we have the final list of tasks, setup the task_list with the
      ! data needed for the loops in integrate_v/calculate_rho

      IF (ASSOCIATED(task_list%taskstart)) THEN
         DEALLOCATE (task_list%taskstart)
      END IF
      IF (ASSOCIATED(task_list%taskstop)) THEN
         DEALLOCATE (task_list%taskstop)
      END IF
      IF (ASSOCIATED(task_list%npairs)) THEN
         DEALLOCATE (task_list%npairs)
      END IF

      ! First, count the number of unique atom pairs per grid level.
      ! NOTE(review): both passes below assume that, after distribute_tasks, the tasks of one
      ! grid level are contiguous and, within a level, the tasks of one (iatom, jatom, image)
      ! are contiguous -- ordering is established outside this view, confirm there.

      ALLOCATE (task_list%npairs(SIZE(rs_descs)))

      iatom_old = -1; jatom_old = -1; igrid_level_old = -1; img_old = -1
      ipair = 0
      task_list%npairs = 0

      DO i = 1, task_list%ntasks
         igrid_level = task_list%tasks(i)%grid_level
         img = task_list%tasks(i)%image
         iatom = task_list%tasks(i)%iatom
         jatom = task_list%tasks(i)%jatom
         iset = task_list%tasks(i)%iset
         jset = task_list%tasks(i)%jset
         ipgf = task_list%tasks(i)%ipgf
         jpgf = task_list%tasks(i)%jpgf
         IF (igrid_level /= igrid_level_old) THEN
            IF (igrid_level_old /= -1) THEN
               task_list%npairs(igrid_level_old) = ipair
            END IF
            ipair = 1
            igrid_level_old = igrid_level
            iatom_old = iatom
            jatom_old = jatom
            img_old = img
         ELSE IF (iatom /= iatom_old .OR. jatom /= jatom_old .OR. img /= img_old) THEN
            ipair = ipair + 1
            iatom_old = iatom
            jatom_old = jatom
            img_old = img
         END IF
      END DO
      ! Take care of the last iteration
      IF (task_list%ntasks /= 0) THEN
         task_list%npairs(igrid_level) = ipair
      END IF

      ! Second, for each atom pair, find the indices in the task list
      ! of the first and last task

      ! Array sized for worst case
      ALLOCATE (task_list%taskstart(MAXVAL(task_list%npairs), SIZE(rs_descs)))
      ALLOCATE (task_list%taskstop(MAXVAL(task_list%npairs), SIZE(rs_descs)))

      iatom_old = -1; jatom_old = -1; igrid_level_old = -1; img_old = -1
      ipair = 0
      task_list%taskstart = 0
      task_list%taskstop = 0

      DO i = 1, task_list%ntasks
         igrid_level = task_list%tasks(i)%grid_level
         img = task_list%tasks(i)%image
         iatom = task_list%tasks(i)%iatom
         jatom = task_list%tasks(i)%jatom
         iset = task_list%tasks(i)%iset
         jset = task_list%tasks(i)%jset
         ipgf = task_list%tasks(i)%ipgf
         jpgf = task_list%tasks(i)%jpgf
         IF (igrid_level /= igrid_level_old) THEN
            IF (igrid_level_old /= -1) THEN
               task_list%taskstop(ipair, igrid_level_old) = i - 1
            END IF
            ipair = 1
            task_list%taskstart(ipair, igrid_level) = i
            igrid_level_old = igrid_level
            iatom_old = iatom
            jatom_old = jatom
            img_old = img
         ELSE IF (iatom /= iatom_old .OR. jatom /= jatom_old .OR. img /= img_old) THEN
            ipair = ipair + 1
            task_list%taskstart(ipair, igrid_level) = i
            task_list%taskstop(ipair - 1, igrid_level) = i - 1
            iatom_old = iatom
            jatom_old = jatom
            img_old = img
         END IF
      END DO
      ! Take care of the last iteration
      IF (task_list%ntasks /= 0) THEN
         task_list%taskstop(ipair, igrid_level) = task_list%ntasks
      END IF

      ! Debug task distribution (compiled in only when debug_this_module is .TRUE.)
      IF (debug_this_module) THEN
         tasks => task_list%tasks
         WRITE (6, *)
         WRITE (6, *) "Total number of tasks              ", task_list%ntasks
         DO igrid_level = 1, gridlevel_info%ngrid_levels
            WRITE (6, *) "Total number of pairs(grid_level)  ", &
               igrid_level, task_list%npairs(igrid_level)
         END DO
         WRITE (6, *)

         DO igrid_level = 1, gridlevel_info%ngrid_levels

            ALLOCATE (blocks(natoms, natoms, nimages))
            blocks = -1
            DO ipair = 1, task_list%npairs(igrid_level)
               itask = task_list%taskstart(ipair, igrid_level)
               ilevel = task_list%tasks(itask)%grid_level
               img = task_list%tasks(itask)%image
               iatom = task_list%tasks(itask)%iatom
               jatom = task_list%tasks(itask)%jatom
               iset = task_list%tasks(itask)%iset
               jset = task_list%tasks(itask)%jset
               ipgf = task_list%tasks(itask)%ipgf
               jpgf = task_list%tasks(itask)%jpgf
               ! each (unordered) atom pair and image may appear only once per grid level
               IF (blocks(iatom, jatom, img) == -1 .AND. blocks(jatom, iatom, img) == -1) THEN
                  blocks(iatom, jatom, img) = 1
                  blocks(jatom, iatom, img) = 1
               ELSE
                  WRITE (6, *) "TASK LIST CONFLICT IN PAIR       ", ipair
                  WRITE (6, *) "Reuse of iatom, jatom, image     ", iatom, jatom, img
               END IF

               iatom_old = iatom
               jatom_old = jatom
               img_old = img
               ! all tasks between taskstart and taskstop must belong to the same pair/image
               DO itask = task_list%taskstart(ipair, igrid_level), task_list%taskstop(ipair, igrid_level)
                  ilevel = task_list%tasks(itask)%grid_level
                  img = task_list%tasks(itask)%image
                  iatom = task_list%tasks(itask)%iatom
                  jatom = task_list%tasks(itask)%jatom
                  iset = task_list%tasks(itask)%iset
                  jset = task_list%tasks(itask)%jset
                  ipgf = task_list%tasks(itask)%ipgf
                  jpgf = task_list%tasks(itask)%jpgf
                  IF (iatom /= iatom_old .OR. jatom /= jatom_old .OR. img /= img_old) THEN
                     WRITE (6, *) "TASK LIST CONFLICT IN TASK       ", itask
                     WRITE (6, *) "Inconsistent iatom, jatom, image ", iatom, jatom, img
                     WRITE (6, *) "Should be    iatom, jatom, image ", iatom_old, jatom_old, img_old
                  END IF

               END DO
            END DO
            DEALLOCATE (blocks)

         END DO

      END IF

      CALL timestop(handle)

   END SUBROUTINE generate_qs_task_list

! **************************************************************************************************
!> \brief Sends the task list data to the grid API.
!>        The per-kind basis sets are converted only while task_list%grid_basis_sets is not yet
!>        associated (i.e. on the first call); the task data itself is flattened into plain
!>        arrays on every call and handed to grid_create_task_list.
!> \param task_list task list whose tasks and receive-side pair offsets are exported; on return
!>        grid_basis_sets, grid_task_list, pab_buffer and hab_buffer are set up
!> \param qs_kind_set kinds from which the basis sets are taken
!> \param particle_set particles providing the atom kinds and positions
!> \param cell cell used to wrap the atom positions with pbc
!> \param basis_type basis set type requested from get_qs_kind
!> \param rs_grids realspace grids forwarded to grid_create_task_list
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE create_grid_task_list(task_list, qs_kind_set, particle_set, cell, basis_type, rs_grids)
      TYPE(task_list_type), POINTER                      :: task_list
      TYPE(qs_kind_type), DIMENSION(:), POINTER          :: qs_kind_set
      TYPE(particle_type), DIMENSION(:), POINTER         :: particle_set
      TYPE(cell_type), POINTER                           :: cell
      CHARACTER(LEN=*)                                   :: basis_type
      TYPE(realspace_grid_type), DIMENSION(:), POINTER :: rs_grids

      INTEGER                                            :: ia, ik, it, n_atoms, n_kinds, n_sets, &
                                                            n_sgf, n_tasks
      INTEGER, DIMENSION(:), ALLOCATABLE                 :: block_nums, border_masks, iatoms, ipgfs, &
                                                            isets, jatoms, jpgfs, jsets, kind_of_atom, &
                                                            levels
      INTEGER, DIMENSION(:), POINTER                     :: lmax_set, lmin_set, npgf_per_set, &
                                                            nsgf_per_set
      INTEGER, DIMENSION(:, :), POINTER                  :: first_sgf_set
      REAL(KIND=dp), DIMENSION(:), ALLOCATABLE           :: radii
      REAL(KIND=dp), DIMENSION(:, :), ALLOCATABLE        :: rabs, wrapped_positions
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: sphi_kind, zet_kind
      TYPE(gto_basis_set_type), POINTER                  :: kind_basis
      TYPE(task_type), DIMENSION(:), POINTER             :: all_tasks

      n_kinds = SIZE(qs_kind_set)
      n_atoms = SIZE(particle_set)
      n_tasks = task_list%ntasks
      all_tasks => task_list%tasks

      ! Flatten the per-task data: one array entry per task.
      ALLOCATE (levels(n_tasks), iatoms(n_tasks), jatoms(n_tasks), &
                isets(n_tasks), jsets(n_tasks), ipgfs(n_tasks), jpgfs(n_tasks), &
                border_masks(n_tasks), block_nums(n_tasks), &
                radii(n_tasks), rabs(3, n_tasks))
      DO it = 1, n_tasks
         levels(it) = all_tasks(it)%grid_level
         iatoms(it) = all_tasks(it)%iatom
         jatoms(it) = all_tasks(it)%jatom
         isets(it) = all_tasks(it)%iset
         jsets(it) = all_tasks(it)%jset
         ipgfs(it) = all_tasks(it)%ipgf
         jpgfs(it) = all_tasks(it)%jpgf
         IF (all_tasks(it)%dist_type /= 2) THEN
            border_masks(it) = 0 ! no masking
         ELSE
            ! keep only the lowest 6 bits, inverted
            border_masks(it) = IAND(63, NOT(all_tasks(it)%subpatch_pattern))
         END IF
         ! the grid API calls the pair index a "block number"
         block_nums(it) = all_tasks(it)%pair_index
         radii(it) = all_tasks(it)%radius
         rabs(:, it) = all_tasks(it)%rab(:)
      END DO

      ! Atom kinds and pbc-wrapped positions, one column per atom.
      ALLOCATE (kind_of_atom(n_atoms), wrapped_positions(3, n_atoms))
      DO ia = 1, n_atoms
         kind_of_atom(ia) = particle_set(ia)%atomic_kind%kind_number
         wrapped_positions(:, ia) = pbc(particle_set(ia)%r, cell)
      END DO

      ! Basis sets do not change during simulation - only need to create them once.
      IF (.NOT. ASSOCIATED(task_list%grid_basis_sets)) THEN
         ALLOCATE (task_list%grid_basis_sets(n_kinds))
         DO ik = 1, n_kinds
            CALL get_qs_kind(qs_kind_set(ik), basis_type=basis_type, basis_set=kind_basis)
            CALL get_gto_basis_set(gto_basis_set=kind_basis, nset=n_sets, nsgf=n_sgf, &
                                   nsgf_set=nsgf_per_set, npgf=npgf_per_set, &
                                   first_sgf=first_sgf_set, lmax=lmax_set, lmin=lmin_set, &
                                   sphi=sphi_kind, zet=zet_kind)
            CALL grid_create_basis_set(nset=n_sets, nsgf=n_sgf, &
                                       maxco=SIZE(sphi_kind, 1), maxpgf=SIZE(zet_kind, 1), &
                                       lmin=lmin_set, lmax=lmax_set, npgf=npgf_per_set, &
                                       nsgf_set=nsgf_per_set, first_sgf=first_sgf_set, &
                                       sphi=sphi_kind, zet=zet_kind, &
                                       basis_set=task_list%grid_basis_sets(ik))
         END DO
      END IF

      CALL grid_create_task_list(ntasks=n_tasks, natoms=n_atoms, nkinds=n_kinds, &
                                 nblocks=SIZE(task_list%pair_offsets_recv), &
                                 block_offsets=task_list%pair_offsets_recv, &
                                 atom_positions=wrapped_positions, atom_kinds=kind_of_atom, &
                                 basis_sets=task_list%grid_basis_sets, &
                                 level_list=levels, iatom_list=iatoms, jatom_list=jatoms, &
                                 iset_list=isets, jset_list=jsets, &
                                 ipgf_list=ipgfs, jpgf_list=jpgfs, &
                                 border_mask_list=border_masks, block_num_list=block_nums, &
                                 radius_list=radii, rab_list=rabs, &
                                 rs_grids=rs_grids, &
                                 task_list=task_list%grid_task_list)

      ! Both buffers are sized by the receive-side buffer size.
      CALL offload_create_buffer(task_list%buffer_size_recv, task_list%pab_buffer)
      CALL offload_create_buffer(task_list%buffer_size_recv, task_list%hab_buffer)

   END SUBROUTINE create_grid_task_list

! **************************************************************************************************
!> \brief Generates the tasks for one atom pair (iatom, jatom) in image cindex.
!>        All set and primitive-Gaussian (pgf) combinations of the two basis sets are visited.
!>        A combination is skipped as soon as the sum of the relevant radii (kind/set radius,
!>        then pgf radius) is smaller than the interatomic distance |rab|. For every surviving
!>        pgf pair the grid level is chosen from the exponent sum, the cube and radius are
!>        computed by compute_pgf_properties, and pgf_to_tasks records the task(s).
!> \param tasks task array, passed on to pgf_to_tasks which stores the new tasks
!> \param ntasks number of stored tasks, passed on to pgf_to_tasks
!> \param curr_tasks passed on to pgf_to_tasks; the caller initialises it to SIZE(tasks)
!>        (presumably the current capacity -- confirm in pgf_to_tasks)
!> \param rs_descs realspace grid descriptors, indexed by grid level
!> \param dft_control provides qs_control%eps_rho_rspace for the radius computation
!> \param cube_info cube information, indexed by grid level
!> \param gridlevel_info maps an exponent to a grid level; provides ngrid_levels
!> \param cindex image index stored with the tasks
!> \param iatom index of the first atom
!> \param jatom index of the second atom
!> \param rpgfa pgf radii of atom a, indexed (ipgf, iset)
!> \param rpgfb pgf radii of atom b, indexed (jpgf, jset)
!> \param zeta Gaussian exponents of atom a, indexed (ipgf, iset)
!> \param zetb Gaussian exponents of atom b, indexed (jpgf, jset)
!> \param kind_radius_b kind radius of atom b (used for the per-set screening of a)
!> \param set_radius_a set radii of atom a
!> \param set_radius_b set radii of atom b
!> \param ra position of atom a
!> \param rab distance vector between the two atoms; its norm is the screening distance
!> \param la_max maximum angular momentum per set of atom a
!> \param la_min minimum angular momentum per set of atom a
!> \param lb_max maximum angular momentum per set of atom b
!> \param lb_min minimum angular momentum per set of atom b
!> \param npgfa number of pgfs per set of atom a
!> \param npgfb number of pgfs per set of atom b
!> \param nseta number of sets of atom a
!> \param nsetb number of sets of atom b
!> \par History
!>      Joost VandeVondele: 10.2008 refactored
! **************************************************************************************************
   SUBROUTINE task_list_inner_loop(tasks, ntasks, curr_tasks, rs_descs, dft_control, &
                                   cube_info, gridlevel_info, cindex, &
                                   iatom, jatom, rpgfa, rpgfb, zeta, zetb, kind_radius_b, set_radius_a, set_radius_b, ra, rab, &
                                   la_max, la_min, lb_max, lb_min, npgfa, npgfb, nseta, nsetb)

      TYPE(task_type), DIMENSION(:), POINTER             :: tasks
      INTEGER                                            :: ntasks, curr_tasks
      TYPE(realspace_grid_desc_p_type), DIMENSION(:), &
         POINTER                                         :: rs_descs
      TYPE(dft_control_type), POINTER                    :: dft_control
      TYPE(cube_info_type), DIMENSION(:), POINTER        :: cube_info
      TYPE(gridlevel_info_type), POINTER                 :: gridlevel_info
      INTEGER                                            :: cindex, iatom, jatom
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: rpgfa, rpgfb, zeta, zetb
      REAL(KIND=dp)                                      :: kind_radius_b
      REAL(KIND=dp), DIMENSION(:), POINTER               :: set_radius_a, set_radius_b
      REAL(KIND=dp), DIMENSION(3)                        :: ra, rab
      INTEGER, DIMENSION(:), POINTER                     :: la_max, la_min, lb_max, lb_min, npgfa, &
                                                            npgfb
      INTEGER                                            :: nseta, nsetb

      INTEGER                                            :: ilevel, ipa, ipb, isa, isb
      INTEGER, DIMENSION(3)                              :: center, cube_hi, cube_lo
      REAL(KIND=dp)                                      :: dist, dist2, pgf_radius, radius_a_pgf, &
                                                            radius_a_set, zet_sum

      ! squared and plain interatomic distance used for screening
      dist2 = rab(1)*rab(1) + rab(2)*rab(2) + rab(3)*rab(3)
      dist = SQRT(dist2)

      sets_of_a: DO isa = 1, nseta
         radius_a_set = set_radius_a(isa)
         ! no set of b can reach this set of a
         IF (radius_a_set + kind_radius_b < dist) CYCLE sets_of_a

         sets_of_b: DO isb = 1, nsetb
            IF (radius_a_set + set_radius_b(isb) < dist) CYCLE sets_of_b

            pgfs_of_a: DO ipa = 1, npgfa(isa)
               radius_a_pgf = rpgfa(ipa, isa)
               IF (radius_a_pgf + set_radius_b(isb) < dist) CYCLE pgfs_of_a

               pgfs_of_b: DO ipb = 1, npgfb(isb)
                  IF (radius_a_pgf + rpgfb(ipb, isb) < dist) CYCLE pgfs_of_b

                  ! the grid level is selected from the sum of the two exponents
                  zet_sum = zeta(ipa, isa) + zetb(ipb, isb)
                  ilevel = gaussian_gridlevel(gridlevel_info, zet_sum)

                  CALL compute_pgf_properties(center, cube_lo, cube_hi, pgf_radius, &
                                              rs_descs(ilevel)%rs_desc, cube_info(ilevel), &
                                              la_max(isa), zeta(ipa, isa), la_min(isa), &
                                              lb_max(isb), zetb(ipb, isb), lb_min(isb), &
                                              ra, rab, dist2, dft_control%qs_control%eps_rho_rspace)

                  CALL pgf_to_tasks(tasks, ntasks, curr_tasks, &
                                    rab, cindex, iatom, jatom, isa, isb, ipa, ipb, &
                                    la_max(isa), lb_max(isb), rs_descs(ilevel)%rs_desc, &
                                    ilevel, gridlevel_info%ngrid_levels, center, &
                                    cube_lo, cube_hi, pgf_radius)
               END DO pgfs_of_b
            END DO pgfs_of_a
         END DO sets_of_b
      END DO sets_of_a

   END SUBROUTINE task_list_inner_loop

! **************************************************************************************************
!> \brief collects the geometric properties of the product of two primitive Gaussians (pgfs)
!>        that are needed to turn it into a collocation/integration task:
!>        the radius of the product, the grid point of its center (folded into the grid) and
!>        the bounds of the cube that encloses it.
!>        The cost of the task is not computed here, see cost_model.
!> \param cube_center on exit, the grid point of the product center, folded into the grid
!>        (MODULO npts) and shifted by the grid lower bound
!> \param lb_cube on exit, lower bounds of the enclosing cube
!> \param ub_cube on exit, upper bounds of the enclosing cube
!> \param radius on exit, the radius of the product as given by exp_radius_very_extended
!> \param rs_desc descriptor of the real-space grid the pgf is mapped on
!> \param cube_info cube information for the grid level of this pgf
!> \param la_max maximum angular momentum of the first set
!> \param zeta exponent of the first pgf
!> \param la_min minimum angular momentum of the first set
!> \param lb_max maximum angular momentum of the second set
!> \param zetb exponent of the second pgf
!> \param lb_min minimum angular momentum of the second set
!> \param ra position of the first center
!> \param rab vector from the first to the second center
!> \param rab2 squared length of rab
!> \param eps target accuracy passed on to exp_radius_very_extended
!> \par History
!>      10.2008 refactored [Joost VandeVondele]
!> \note
!>      -) this requires the radius to be computed in the same way as
!>      collocate_pgf_product, we should factor that part into a subroutine
!>      -) in principle, the computed radius could be recycled in integrate_pgf/collocate_pgf if it is certainly
!>         the same, this could lead to a small speedup
!>      -) for load balancing, integrate_pgf and collocate_pgf are assumed to have the same cost
!>         (more or less true for map_consistent); the cost fit itself lives in cost_model
! **************************************************************************************************
   SUBROUTINE compute_pgf_properties(cube_center, lb_cube, ub_cube, radius, &
                                     rs_desc, cube_info, la_max, zeta, la_min, lb_max, zetb, lb_min, ra, rab, rab2, eps)

      INTEGER, DIMENSION(3), INTENT(OUT)                 :: cube_center, lb_cube, ub_cube
      REAL(KIND=dp), INTENT(OUT)                         :: radius
      TYPE(realspace_grid_desc_type), POINTER            :: rs_desc
      TYPE(cube_info_type), INTENT(IN)                   :: cube_info
      INTEGER, INTENT(IN)                                :: la_max
      REAL(KIND=dp), INTENT(IN)                          :: zeta
      INTEGER, INTENT(IN)                                :: la_min, lb_max
      REAL(KIND=dp), INTENT(IN)                          :: zetb
      INTEGER, INTENT(IN)                                :: lb_min
      REAL(KIND=dp), INTENT(IN)                          :: ra(3), rab(3), rab2, eps

      INTEGER, DIMENSION(3)                              :: cube_extent
      INTEGER, DIMENSION(:), POINTER                     :: sphere_bounds
      REAL(KIND=dp)                                      :: exp_cutoff, gauss_prefactor, weight_b, &
                                                            zeta_sum
      REAL(KIND=dp), DIMENSION(3)                        :: rb_pos, rp_pos

      ! Gaussian product theorem: combined exponent, second-center weight, product center
      zeta_sum = zeta + zetb
      weight_b = zetb/zeta_sum
      rb_pos(:) = ra(:) + rab(:)
      rp_pos(:) = ra(:) + weight_b*rab(:)
      gauss_prefactor = EXP(-zeta*weight_b*rab2)
      exp_cutoff = 1.0_dp
      radius = exp_radius_very_extended(la_min, la_max, lb_min, lb_max, ra=ra, rb=rb_pos, rp=rp_pos, &
                                        zetp=zeta_sum, eps=eps, prefactor=gauss_prefactor, cutoff=exp_cutoff)

      ! grid point of the product center, folded into the grid and offset by its lower bound
      CALL compute_cube_center(cube_center, rs_desc, zeta, zetb, ra, rab)
      cube_center(:) = MODULO(cube_center(:), rs_desc%npts(:)) + rs_desc%lb(:)

      IF (.NOT. rs_desc%orthorhombic) THEN
         CALL return_cube_nonortho(cube_info, radius, lb_cube, ub_cube, rp_pos)
         !? unclear if extent is computed correctly.
         cube_extent(:) = ub_cube(:) - lb_cube(:)
         lb_cube(:) = -cube_extent(:)/2 - 1
         ub_cube(:) = cube_extent(:)/2
      ELSE
         CALL return_cube(cube_info, radius, lb_cube, ub_cube, sphere_bounds)
      END IF

   END SUBROUTINE compute_pgf_properties
! **************************************************************************************************
!> \brief predicts the cost of a task in kcycles for a given task
!>        the model is based on a fit of actual data, and might need updating
!>        as collocate_pgf_product changes (or CPUs/compilers change)
!>        maybe some dynamic approach, improving the cost model on the fly could
!>        work as well
!>        the cost model does not yet take into account the fraction of space
!>        that is mapped locally for a given cube and rs_grid (generalised tasks)
!>        The fit has the form a*(l+b)*(r+c)**3*fraction + d + e*l**7, with r the half
!>        edge length of the largest cube dimension in grid points, and was measured
!>        originally on an opteron/g95 by fitting the median cost of mapping a pgf.
!> \param lb_cube lower bounds of the cube to be mapped
!> \param ub_cube upper bounds of the cube to be mapped
!> \param fraction share of the work attributed to this task (1/number of tasks sharing the pgf)
!> \param lmax combined maximum angular momentum of the pgf pair
!> \param is_ortho whether the orthorhombic fit parameters apply
!> \return predicted cost, rounded up to an integer number of kcycles
! **************************************************************************************************
   INTEGER FUNCTION cost_model(lb_cube, ub_cube, fraction, lmax, is_ortho)
      INTEGER, DIMENSION(3), INTENT(IN)                  :: lb_cube, ub_cube
      REAL(KIND=dp), INTENT(IN)                          :: fraction
      INTEGER, INTENT(IN)                                :: lmax
      LOGICAL, INTENT(IN)                                :: is_ortho

      INTEGER                                            :: cmax
      REAL(KIND=dp)                                      :: v1, v2, v3, v4, v5

      ! half of the largest cube edge, in grid points (integer division is intended)
      cmax = MAXVAL(((ub_cube - lb_cube) + 1)/2)

      IF (is_ortho) THEN
         v1 = 1.504760E+00_dp
         v2 = 3.126770E+00_dp
         v3 = 5.074106E+00_dp
         v4 = 1.091568E+00_dp
         v5 = 1.070187E+00_dp
      ELSE
         v1 = 7.831105E+00_dp
         v2 = 2.675174E+00_dp
         v3 = 7.546553E+00_dp
         v4 = 6.122446E-01_dp
         v5 = 3.886382E+00_dp
      END IF
      ! lmax**7 is evaluated in REAL(dp): as a default INTEGER it overflows for lmax >= 22
      cost_model = CEILING(((lmax + v1)*(cmax + v2)**3*v3*fraction + v4 + v5*REAL(lmax, dp)**7)/1000.0_dp)

   END FUNCTION cost_model
! **************************************************************************************************
!> \brief pgf_to_tasks converts a given pgf to one or more tasks, in particular
!>        this determines by which CPUs a given pgf gets collocated.
!>        One new entry is appended to tasks for a replicated grid. For a distributed grid,
!>        rs_find_node appends the entries and reports how many it added. Every new entry then
!>        receives the source rank, the predicted cost and the pgf identification
!>        (image, atoms, sets, pgfs, rab, radius).
!>        dist_type is 0 for replicated, 1 for distributed local and 2 for generalised tasks.
!> \param tasks task list, grown through reallocate_tasks when ntasks exceeds curr_tasks
!> \param ntasks number of tasks in use, advanced by the number of tasks added
!> \param curr_tasks current capacity of tasks as tracked by the caller
!> \param rab vector between the two atoms of the pair
!> \param cindex image index of the pair
!> \param iatom first atom of the pair
!> \param jatom second atom of the pair
!> \param iset set index on iatom
!> \param jset set index on jatom
!> \param ipgf pgf index within iset
!> \param jpgf pgf index within jset
!> \param la_max maximum angular momentum of iset
!> \param lb_max maximum angular momentum of jset
!> \param rs_desc descriptor of the real-space grid of this grid level
!> \param igrid_level grid level on which the pgf is mapped
!> \param n_levels total number of grid levels
!> \param cube_center grid point of the product center
!> \param lb_cube lower bounds of the enclosing cube
!> \param ub_cube upper bounds of the enclosing cube
!> \param radius radius of the product
!> \par History
!>      10.2008 Refactored based on earlier routines by MattW [Joost VandeVondele]
! **************************************************************************************************
   SUBROUTINE pgf_to_tasks(tasks, ntasks, curr_tasks, &
                           rab, cindex, iatom, jatom, iset, jset, ipgf, jpgf, &
                           la_max, lb_max, rs_desc, igrid_level, n_levels, &
                           cube_center, lb_cube, ub_cube, radius)

      TYPE(task_type), DIMENSION(:), POINTER             :: tasks
      INTEGER, INTENT(INOUT)                             :: ntasks, curr_tasks
      REAL(KIND=dp), DIMENSION(3), INTENT(IN)            :: rab
      INTEGER, INTENT(IN)                                :: cindex, iatom, jatom, iset, jset, ipgf, &
                                                            jpgf, la_max, lb_max
      TYPE(realspace_grid_desc_type), POINTER            :: rs_desc
      INTEGER, INTENT(IN)                                :: igrid_level, n_levels
      INTEGER, DIMENSION(3), INTENT(IN)                  :: cube_center, lb_cube, ub_cube
      REAL(KIND=dp), INTENT(IN)                          :: radius

      INTEGER, PARAMETER                                 :: add_tasks = 1000
      REAL(kind=dp), PARAMETER                           :: mult_tasks = 2.0_dp

      INTEGER                                            :: first_new, itask, lsum, n_new, task_cost
      LOGICAL                                            :: ortho_model
      REAL(KIND=dp)                                      :: share

!$OMP SINGLE
      ntasks = ntasks + 1
      IF (ntasks > curr_tasks) THEN
         curr_tasks = INT((curr_tasks + add_tasks)*mult_tasks)
         CALL reallocate_tasks(tasks, curr_tasks)
      END IF
!$OMP END SINGLE

      IF (.NOT. rs_desc%distributed) THEN
         ! replicated grid: the local rank maps the whole pgf
         tasks(ntasks)%destination = encode_rank(rs_desc%my_pos, igrid_level, n_levels)
         tasks(ntasks)%dist_type = 0
         tasks(ntasks)%subpatch_pattern = 0
         n_new = 1
      ELSE
         ! distributed grid: rs_find_node decides which rank(s) process the pgf;
         ! dist_type becomes 1 for distributed and 2 for generalised tasks
         CALL rs_find_node(rs_desc, igrid_level, n_levels, cube_center, &
                           ntasks=ntasks, tasks=tasks, lb_cube=lb_cube, ub_cube=ub_cube, added_tasks=n_new)
      END IF

      lsum = la_max + lb_max
      ortho_model = rs_desc%orthorhombic .AND. &
                    (tasks(ntasks)%dist_type == 0 .OR. tasks(ntasks)%dist_type == 1)
      ! the work of a generalised Gaussian is assumed to be split evenly among its tasks
      ! (this could be refined in the future)
      share = 1.0_dp/n_new
      task_cost = cost_model(lb_cube, ub_cube, share, lsum, ortho_model)

      first_new = ntasks - n_new + 1
      DO itask = first_new, ntasks
         tasks(itask)%source = encode_rank(rs_desc%my_pos, igrid_level, n_levels)
         tasks(itask)%cost = task_cost
         tasks(itask)%grid_level = igrid_level
         tasks(itask)%image = cindex
         tasks(itask)%iatom = iatom
         tasks(itask)%jatom = jatom
         tasks(itask)%iset = iset
         tasks(itask)%jset = jset
         tasks(itask)%ipgf = ipgf
         tasks(itask)%jpgf = jpgf
         tasks(itask)%rab = rab
         tasks(itask)%radius = radius
      END DO

   END SUBROUTINE pgf_to_tasks

! **************************************************************************************************
!> \brief performs load balancing of the tasks on the distributed grids of one grid level
!>        The steps are:
!>        1. build the table of possible destinations;
!>        2. accumulate the task costs into it;
!>        3. balance the table across ranks;
!>        4. reassign task destinations according to the balanced table.
!> \param tasks task list whose destinations may be changed
!> \param ntasks number of valid entries in tasks
!> \param rs_descs real-space grid descriptors of all grid levels
!> \param grid_level grid level being balanced
!> \param natoms number of atoms
!> \par History
!>      created 2008-10-03 [Joost VandeVondele]
! **************************************************************************************************
   SUBROUTINE load_balance_distributed(tasks, ntasks, rs_descs, grid_level, natoms)

      TYPE(task_type), DIMENSION(:), POINTER             :: tasks
      INTEGER                                            :: ntasks
      TYPE(realspace_grid_desc_p_type), DIMENSION(:), &
         POINTER                                         :: rs_descs
      INTEGER                                            :: grid_level, natoms

      CHARACTER(LEN=*), PARAMETER :: routineN = 'load_balance_distributed'

      INTEGER                                            :: handle
      INTEGER, DIMENSION(:, :, :), POINTER               :: dest_table

      CALL timeset(routineN, handle)

      ! for each cpu (0:ncpu-1), the admissible destinations; a destination that is
      ! missing from this table would indicate a bug
      NULLIFY (dest_table)
      CALL create_destination_list(dest_table, rs_descs, grid_level)

      ! first pass over the tasks: record the load of each destination
      CALL compute_load_list(dest_table, rs_descs, grid_level, tasks, ntasks, natoms, create_list=.TRUE.)

      ! balance loads and fluxes using the communicator of the first grid descriptor
      CALL optimize_load_list(dest_table, rs_descs(1)%rs_desc%group, rs_descs(1)%rs_desc%my_pos)

      ! second pass over the tasks: move them according to the balanced table
      CALL compute_load_list(dest_table, rs_descs, grid_level, tasks, ntasks, natoms, create_list=.FALSE.)

      DEALLOCATE (dest_table)

      CALL timestop(handle)

   END SUBROUTINE load_balance_distributed

! **************************************************************************************************
!> \brief this serial routine adjusts the fluxes in the global list
!>        list_global(1, idest, icpu) is a destination rank of rank icpu.
!>        list_global(2, idest, icpu) is the load of icpu that may go to that destination.
!>        For every pair of ranks (icpu, dest) with icpu < dest a "flux" is set up, bounded
!>        by what each side offers the other. An iterative sweep shifts load across the
!>        fluxes to even out the total load per rank. The resulting fluxes are then
!>        written back into list_global(2, :, :).
!>
!> \param list_global destination/load table of all ranks (rank dimension starts at 0);
!>        on exit the loads hold the amounts each rank should hand to each destination
!> \par History
!>      created 2008-10-06 [Joost VandeVondele]
!> \note assumes the table is symmetric and that each rank appears in its own destination
!>       list (ilocal below is only set in that case); entries with dest >= ncpu are ignored
! **************************************************************************************************
   SUBROUTINE balance_global_list(list_global)
      INTEGER, DIMENSION(:, :, 0:)                       :: list_global

      CHARACTER(LEN=*), PARAMETER :: routineN = 'balance_global_list'
      ! iteration limit and convergence threshold (relative to the average load)
      INTEGER, PARAMETER                                 :: Max_Iter = 100
      REAL(KIND=dp), PARAMETER                           :: Tolerance_factor = 0.005_dp

      INTEGER                                            :: dest, handle, icpu, idest, iflux, &
                                                            ilocal, k, maxdest, Ncpu, Nflux
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: flux_connections
      LOGICAL                                            :: solution_optimal
      REAL(KIND=dp)                                      :: average, load_shift, max_load_shift, &
                                                            tolerance
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:)           :: load, optimized_flux, optimized_load
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:, :)        :: flux_limits

      CALL timeset(routineN, handle)

      Ncpu = SIZE(list_global, 3)
      maxdest = SIZE(list_global, 2)
      ALLOCATE (load(0:Ncpu - 1))
      load = 0.0_dp
      ALLOCATE (optimized_load(0:Ncpu - 1))

      ! figure out the number of fluxes
      ! we assume that the global_list is symmetric
      ! (one flux per pair icpu < dest, counted from the lower rank's side)
      Nflux = 0
      DO icpu = 0, ncpu - 1
         DO idest = 1, maxdest
            dest = list_global(1, idest, icpu)
            IF (dest < ncpu .AND. dest > icpu) Nflux = Nflux + 1
         END DO
      END DO
      ALLOCATE (optimized_flux(Nflux))
      ALLOCATE (flux_limits(2, Nflux))
      ALLOCATE (flux_connections(2, Nflux))

      ! reorder data
      ! load(icpu): total load currently on icpu
      ! flux_connections(:, iflux) = [icpu, dest] with icpu < dest
      ! flux_limits(2, iflux): load icpu may hand to dest (upper bound, positive flux)
      ! flux_limits(1, iflux): minus the load dest may hand to icpu (lower bound, negative flux)
      flux_limits = 0
      Nflux = 0
      DO icpu = 0, ncpu - 1
         load(icpu) = SUM(list_global(2, :, icpu))
         DO idest = 1, maxdest
            dest = list_global(1, idest, icpu)
            IF (dest < ncpu) THEN
               IF (dest /= icpu) THEN
                  IF (dest > icpu) THEN
                     Nflux = Nflux + 1
                     flux_limits(2, Nflux) = list_global(2, idest, icpu)
                     flux_connections(1, Nflux) = icpu
                     flux_connections(2, Nflux) = dest
                  ELSE
                     ! the flux for this pair was created while visiting the lower rank dest
                     DO iflux = 1, Nflux
                        IF (flux_connections(1, iflux) == dest .AND. flux_connections(2, iflux) == icpu) THEN
                           flux_limits(1, iflux) = -list_global(2, idest, icpu)
                           EXIT
                        END IF
                     END DO
                  END IF
               END IF
            END IF
         END DO
      END DO

      solution_optimal = .FALSE.
      optimized_flux = 0.0_dp

      ! an iterative solver, if iterated till convergence the maximum load is minimal
      ! we terminate before things are fully converged, since this does show up in the timings
      ! once the largest shift becomes less than a small fraction of the average load, we're done
      ! we're perfectly happy if the load balance is within 1 percent or so
      ! the maximum load normally converges even faster
      average = SUM(load)/SIZE(load)
      tolerance = Tolerance_factor*average

      ! each sweep moves half the load difference across every flux, clamped so that the
      ! accumulated flux stays within [flux_limits(1, iflux), flux_limits(2, iflux)]
      ! NOTE: solution_optimal records convergence but is not read afterwards
      optimized_load(:) = load
      DO k = 1, Max_iter
         max_load_shift = 0.0_dp
         DO iflux = 1, Nflux
            load_shift = (optimized_load(flux_connections(1, iflux)) - optimized_load(flux_connections(2, iflux)))/2
            load_shift = MAX(flux_limits(1, iflux) - optimized_flux(iflux), load_shift)
            load_shift = MIN(flux_limits(2, iflux) - optimized_flux(iflux), load_shift)
            max_load_shift = MAX(ABS(load_shift), max_load_shift)
            optimized_load(flux_connections(1, iflux)) = optimized_load(flux_connections(1, iflux)) - load_shift
            optimized_load(flux_connections(2, iflux)) = optimized_load(flux_connections(2, iflux)) + load_shift
            optimized_flux(iflux) = optimized_flux(iflux) + load_shift
         END DO
         IF (max_load_shift < tolerance) THEN
            solution_optimal = .TRUE.
            EXIT
         END IF
      END DO

      ! now adjust the load list to reflect the optimized fluxes
      ! reorder data
      ! the flux amount (rounded with NINT) becomes the load handed to the other rank;
      ! whatever is not handed over is added to the rank's own entry ilocal
      Nflux = 0
      DO icpu = 0, ncpu - 1
         DO idest = 1, maxdest
            IF (list_global(1, idest, icpu) == icpu) ilocal = idest
         END DO
         DO idest = 1, maxdest
            dest = list_global(1, idest, icpu)
            IF (dest < ncpu) THEN
               IF (dest /= icpu) THEN
                  IF (dest > icpu) THEN
                     ! icpu is the lower rank: a positive flux means icpu hands load to dest
                     Nflux = Nflux + 1
                     IF (optimized_flux(Nflux) > 0) THEN
                        list_global(2, ilocal, icpu) = list_global(2, ilocal, icpu) + &
                                                       list_global(2, idest, icpu) - NINT(optimized_flux(Nflux))
                        list_global(2, idest, icpu) = NINT(optimized_flux(Nflux))
                     ELSE
                        list_global(2, ilocal, icpu) = list_global(2, ilocal, icpu) + &
                                                       list_global(2, idest, icpu)
                        list_global(2, idest, icpu) = 0
                     END IF
                  ELSE
                     ! icpu is the higher rank: a negative flux means icpu hands load to dest
                     DO iflux = 1, Nflux
                        IF (flux_connections(1, iflux) == dest .AND. flux_connections(2, iflux) == icpu) THEN
                           IF (optimized_flux(iflux) > 0) THEN
                              list_global(2, ilocal, icpu) = list_global(2, ilocal, icpu) + &
                                                             list_global(2, idest, icpu)
                              list_global(2, idest, icpu) = 0
                           ELSE
                              list_global(2, ilocal, icpu) = list_global(2, ilocal, icpu) + &
                                                             list_global(2, idest, icpu) + NINT(optimized_flux(iflux))
                              list_global(2, idest, icpu) = -NINT(optimized_flux(iflux))
                           END IF
                           EXIT
                        END IF
                     END DO
                  END IF
               END IF
            END IF
         END DO
      END DO

      CALL timestop(handle)

   END SUBROUTINE balance_global_list

! **************************************************************************************************
!> \brief this routine gets back optimized loads for all destinations
!>        The root rank sums the loads of all ranks, balances them (balance_global_list)
!>        and broadcasts the result. Each rank then claims its share of every allowed
!>        transfer, in rank order, using an inclusive prefix sum over ranks of the local loads.
!>
!> \param list destination/load table of this rank (rank dimension starts at 0);
!>        on exit list(2, :, :) holds the amount this rank may move to each destination
!> \param group communicator over which the loads are combined
!> \param my_pos rank of this process in group
!> \par History
!>      created 2008-10-06 [Joost VandeVondele]
!>      Modified 2016-01   [EPCC] Reduce memory requirements on P processes
!>                                from O(P^2) to O(P)
!>      The prefix sum now uses the local loads on every rank. On the root rank load_all
!>      holds the global sum after the reduction, so it must not be used for this.
! **************************************************************************************************
   SUBROUTINE optimize_load_list(list, group, my_pos)
      INTEGER, DIMENSION(:, :, 0:)                       :: list
      TYPE(mp_comm_type), INTENT(IN) :: group
      INTEGER, INTENT(IN)                                :: my_pos

      CHARACTER(LEN=*), PARAMETER :: routineN = 'optimize_load_list'
      INTEGER, PARAMETER                                 :: rank_of_root = 0

      INTEGER                                            :: handle, icpu, idest, load_before, maxdest, &
                                                            ncpu
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: load_all
      INTEGER, ALLOCATABLE, DIMENSION(:, :)              :: load_local, load_partial
      INTEGER, ALLOCATABLE, DIMENSION(:, :, :)           :: list_global

      CALL timeset(routineN, handle)

      ncpu = SIZE(list, 3)
      maxdest = SIZE(list, 2)

      ! contiguous copy of this rank's own loads (without leading dimension/stride);
      ! kept apart from load_all because the reduction below overwrites load_all on the root
      ALLOCATE (load_local(maxdest, ncpu))
      load_local(:, :) = list(2, :, :)

      !find total workload ...
      ALLOCATE (load_all(maxdest*ncpu))
      load_all(:) = RESHAPE(load_local, [maxdest*ncpu])
      CALL group%sum(load_all(:), rank_of_root)

      ! ... and optimise the work per process
      ALLOCATE (list_global(2, maxdest, ncpu))
      IF (rank_of_root == my_pos) THEN
         list_global(1, :, :) = list(1, :, :)
         list_global(2, :, :) = RESHAPE(load_all, [maxdest, ncpu])
         CALL balance_global_list(list_global)
      END IF
      CALL group%bcast(list_global, rank_of_root)

      !figure out how much can be sent to other processes:
      ! load_partial = sum of the local loads of all ranks up to and including this one
      ALLOCATE (load_partial(maxdest, ncpu))
      CALL group%sum_partial(load_local, load_partial(:, :))

      ! 'list' is indexed from zero in its last dimension, hence icpu - 1
      DO icpu = 1, ncpu
         DO idest = 1, maxdest
            IF (load_partial(idest, icpu) > list_global(2, idest, icpu)) THEN
               ! part of the allowance already claimed by lower ranks
               load_before = load_partial(idest, icpu) - load_local(idest, icpu)
               IF (load_before < list_global(2, idest, icpu)) THEN
                  list(2, idest, icpu - 1) = list_global(2, idest, icpu) - load_before
               ELSE
                  list(2, idest, icpu - 1) = 0
               END IF
            END IF
         END DO
      END DO

      !clean up before leaving
      DEALLOCATE (load_all)
      DEALLOCATE (load_local)
      DEALLOCATE (list_global)
      DEALLOCATE (load_partial)

      CALL timestop(handle)
   END SUBROUTINE optimize_load_list

! **************************************************************************************************
!> \brief fill the load list with values derived from the tasks array, or use the
!>        balanced load list to reassign task destinations.
!>        Tasks are processed in consecutive runs that share the same atom pair and image.
!>        For a distributed task (dist_type 1), a neighbouring rank flagged in its
!>        subpatch_pattern is an acceptable alternate destination only if it already
!>        receives another task of the same atom pair on this grid level. This way the
!>        number of matrix blocks needed to distribute does not increase. Among the
!>        acceptable alternates, the one with the lowest current load is chosen.
!>        Generalised tasks (dist_type 2) keep their destination.
!>        Replicated tasks are not yet considered (they abort).
!>
!> \param list destination/load table, indexed (1:2, idest, rank) with rank starting at 0
!> \param rs_descs real-space grid descriptors of all grid levels
!> \param grid_level only tasks on this grid level are processed
!> \param tasks task list; destinations are rewritten when create_list is .FALSE.
!> \param ntasks number of valid entries in tasks
!> \param natoms number of atoms, used to form a unique index per atom pair
!> \param create_list .TRUE.: add task costs to list(2, :, :);
!>        .FALSE.: consume list(2, :, :) to move tasks to alternate destinations
!> \par History
!>      created 2008-10-06 [Joost VandeVondele]
! **************************************************************************************************
   SUBROUTINE compute_load_list(list, rs_descs, grid_level, tasks, ntasks, natoms, create_list)
      INTEGER, DIMENSION(:, :, 0:)                       :: list
      TYPE(realspace_grid_desc_p_type), DIMENSION(:), &
         POINTER                                         :: rs_descs
      INTEGER                                            :: grid_level
      TYPE(task_type), DIMENSION(:), POINTER             :: tasks
      INTEGER                                            :: ntasks, natoms
      LOGICAL                                            :: create_list

      CHARACTER(LEN=*), PARAMETER :: routineN = 'compute_load_list'
      ! neighbour shifts probed for bits 0..5 of subpatch_pattern: -x, +x, -y, +y, -z, +z
      INTEGER, DIMENSION(3, 6), PARAMETER :: neighbour_shift = RESHAPE([-1, 0, 0, +1, 0, 0, &
                                                                       0, -1, 0, 0, +1, 0, &
                                                                       0, 0, -1, 0, 0, +1], [3, 6])

      INTEGER :: cost, dest, handle, i, idir, ilevel, img, img_first, iopt, itask, &
                 itask_start, itask_stop, li, ndest_pair, nopt, nshort, rank
      INTEGER(KIND=int_8)                                :: bit_pattern, ipair, ipair_first, natom8
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: loads
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: all_dests, sort_idx
      INTEGER, DIMENSION(6)                              :: options

      CALL timeset(routineN, handle)

      ALLOCATE (loads(0:rs_descs(grid_level)%rs_desc%group_size - 1))
      CALL get_current_loads(loads, rs_descs, grid_level, ntasks, tasks, use_reordered_ranks=.FALSE.)

      natom8 = natoms

      itask_stop = 0
      DO
         ! find the range of consecutive tasks that deal with the same atom pair and image
         itask_start = itask_stop + 1
         IF (itask_start > ntasks) EXIT
         itask_stop = itask_start
         img_first = tasks(itask_start)%image
         ipair_first = (tasks(itask_start)%iatom - 1)*natom8 + (tasks(itask_start)%jatom - 1)
         DO
            IF (itask_stop >= ntasks) EXIT
            img = tasks(itask_stop + 1)%image
            ipair = (tasks(itask_stop + 1)%iatom - 1)*natom8 + (tasks(itask_stop + 1)%jatom - 1)
            IF (ipair /= ipair_first .OR. img /= img_first) EXIT
            itask_stop = itask_stop + 1
         END DO
         nshort = itask_stop - itask_start + 1

         ! sorted unique list of destinations of this range on this grid level only;
         ! tasks on other levels are marked HUGE so they sort last and are dropped
         IF (ALLOCATED(all_dests)) DEALLOCATE (all_dests, sort_idx)
         ALLOCATE (all_dests(nshort), sort_idx(nshort))
         DO i = 1, nshort
            itask = itask_start + i - 1
            IF (tasks(itask)%grid_level == grid_level) THEN
               all_dests(i) = decode_rank(tasks(itask)%destination, SIZE(rs_descs))
            ELSE
               all_dests(i) = HUGE(all_dests(i))
            END IF
         END DO
         CALL sort(all_dests, nshort, sort_idx)
         ndest_pair = 1
         DO i = 2, nshort
            IF ((all_dests(ndest_pair) /= all_dests(i)) .AND. (all_dests(i) /= HUGE(all_dests(i)))) THEN
               ndest_pair = ndest_pair + 1
               all_dests(ndest_pair) = all_dests(i)
            END IF
         END DO

         DO itask = itask_start, itask_stop

            dest = decode_rank(tasks(itask)%destination, SIZE(rs_descs)) ! notice that dest can be changed
            ilevel = tasks(itask)%grid_level

            ! Only proceed with tasks which are on this grid level
            IF (ilevel /= grid_level) CYCLE
            cost = INT(tasks(itask)%cost)

            rank = dest
            SELECT CASE (tasks(itask)%dist_type)
            CASE (1)
               ! collect the flagged neighbours that already serve this atom pair
               bit_pattern = tasks(itask)%subpatch_pattern
               nopt = 0
               DO idir = 1, 6
                  IF (.NOT. BTEST(bit_pattern, idir - 1)) CYCLE
                  rank = rs_grid_locate_rank(rs_descs(ilevel)%rs_desc, dest, neighbour_shift(:, idir))
                  IF (ANY(all_dests(1:ndest_pair) == rank)) THEN
                     nopt = nopt + 1
                     options(nopt) = rank
                  END IF
               END DO
               IF (nopt > 0) THEN
                  ! set it to the rank with the lowest load (first one wins ties)
                  rank = options(1)
                  DO iopt = 2, nopt
                     IF (loads(rank) > loads(options(iopt))) rank = options(iopt)
                  END DO
               ELSE
                  rank = dest
               END IF
            CASE (2) ! generalised: only the original destination is considered
               rank = dest
            CASE DEFAULT
               CPABORT("")
            END SELECT

            li = list_index(list, rank, dest)
            IF (create_list) THEN
               list(2, li, dest) = list(2, li, dest) + cost
            ELSE IF (list(1, li, dest) /= dest .AND. list(2, li, dest) >= cost) THEN
               ! enough balanced allowance left: move the task to the alternate rank
               list(2, li, dest) = list(2, li, dest) - cost
               tasks(itask)%destination = encode_rank(list(1, li, dest), ilevel, SIZE(rs_descs))
            ELSE
               tasks(itask)%destination = encode_rank(dest, ilevel, SIZE(rs_descs))
            END IF

         END DO

      END DO

      CALL timestop(handle)

   END SUBROUTINE compute_load_list
! **************************************************************************************************
!> \brief small helper function to return the proper index in the list array
!>
!>        Searches the candidate destinations of process dest, i.e. list(1, :, dest), for rank
!>        and returns its position. The search is bounded by the extent of the list; if rank is
!>        not present the run is aborted instead of reading past the end of the array.
!> \param list destination list as built by create_destination_list:
!>             list(1, :, dest) holds candidate ranks, list(2, :, dest) their accumulated load
!> \param rank the rank to look up
!> \param dest the process whose destination list is searched
!> \return index i such that list(1, i, dest) == rank
!> \par History
!>      created 2008-10-06 [Joost VandeVondele]
! **************************************************************************************************
   INTEGER FUNCTION list_index(list, rank, dest)
      INTEGER, DIMENSION(:, :, 0:), INTENT(IN)           :: list
      INTEGER, INTENT(IN)                                :: rank, dest

      INTEGER                                            :: i

      list_index = 0
      DO i = 1, SIZE(list, 2)
         IF (list(1, i, dest) == rank) THEN
            list_index = i
            RETURN
         END IF
      END DO

      ! every rank looked up must be one of the candidate destinations of dest;
      ! reaching this point means the list is inconsistent with the caller
      CPABORT("rank not found in destination list")
   END FUNCTION list_index
! **************************************************************************************************
!> \brief create a list with possible destinations (i.e. the central cpu and neighbors) for each cpu
!>        note that we allocate it with an additional field to store the load of this destination
!>
!>        For every cpu the candidates are the cpu itself plus the ranks returned by
!>        rs_grid_locate_rank for the six shifts (+/-1 along each of the three directions).
!>        The candidates are sorted and duplicates removed; the second dimension of the list
!>        is finally shrunk to the largest number of distinct candidates over all cpus.
!> \param list must not be associated on entry; on exit list(1, :, icpu) holds the sorted,
!>             unique candidate ranks of icpu (padded with HUGE) and list(2, :, icpu) is zero
!> \param rs_descs the grid descriptors
!> \param grid_level the grid level whose descriptor determines the neighbours
!> \par History
!>      created 2008-10-06 [Joost VandeVondele]
! **************************************************************************************************
   SUBROUTINE create_destination_list(list, rs_descs, grid_level)
      INTEGER, DIMENSION(:, :, :), POINTER               :: list
      TYPE(realspace_grid_desc_p_type), DIMENSION(:), &
         POINTER                                         :: rs_descs
      INTEGER, INTENT(IN)                                :: grid_level

      CHARACTER(LEN=*), PARAMETER :: routineN = 'create_destination_list'
      ! the cpu itself plus its six neighbours: at most 7 distinct destinations
      INTEGER, PARAMETER                                 :: max_dest = 7
      ! neighbour shifts, one per column, in the order -x, +x, -y, +y, -z, +z
      INTEGER, DIMENSION(3, max_dest - 1), PARAMETER :: &
         shifts = RESHAPE([-1, 0, 0, +1, 0, 0, 0, -1, 0, 0, +1, 0, 0, 0, -1, 0, 0, +1], [3, max_dest - 1])

      INTEGER                                            :: handle, idest, ishift, my_cpu, nprocs, &
                                                            nunique, widest
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: candidates, perm

      CALL timeset(routineN, handle)

      CPASSERT(.NOT. ASSOCIATED(list))
      nprocs = rs_descs(grid_level)%rs_desc%group_size

      ALLOCATE (list(2, max_dest, 0:nprocs - 1))

      ALLOCATE (perm(max_dest))
      ALLOCATE (candidates(max_dest))
      candidates(:) = HUGE(candidates)

      widest = 1
      DO my_cpu = 0, nprocs - 1
         ! the cpu itself, followed by the owner of each neighbouring position
         candidates(1) = my_cpu
         DO ishift = 1, max_dest - 1
            candidates(ishift + 1) = rs_grid_locate_rank(rs_descs(grid_level)%rs_desc, my_cpu, shifts(:, ishift))
         END DO

         ! sort, then compact runs of equal values so that each destination appears once
         CALL sort(candidates, max_dest, perm)
         nunique = 1
         DO idest = 2, max_dest
            IF (candidates(idest) == candidates(nunique)) CYCLE
            nunique = nunique + 1
            candidates(nunique) = candidates(idest)
         END DO
         widest = MAX(widest, nunique)

         ! pad the unused slots and store the candidates with a zero initial load
         candidates(nunique + 1:max_dest) = HUGE(candidates)
         list(1, :, my_cpu) = candidates
         list(2, :, my_cpu) = 0
      END DO

      ! keep only as many slots as the cpu with the most distinct destinations needs
      CALL reallocate(list, 1, 2, 1, widest, 0, nprocs - 1)

      DEALLOCATE (perm)
      DEALLOCATE (candidates)

      CALL timestop(handle)

   END SUBROUTINE create_destination_list

! **************************************************************************************************
!> \brief given a task list, compute the load of each process everywhere
!>        giving this function the ability to loop over a (sub)set of rs_grids,
!>        and do all the communication in one shot, would speed it up
!> \param loads on exit, loads(p+1) (1-based inside this routine) is the summed cost of all
!>              tasks of level grid_level, from all processes, that are destined for process p
!> \param rs_descs the grid descriptors
!> \param grid_level only tasks on this grid level are counted; its group is used for communication
!> \param ntasks number of entries of tasks to consider
!> \param tasks the local task list
!> \param use_reordered_ranks if true, task destinations are mapped through virtual2real
!>                            of this grid level before being accounted
!> \par History
!>      none
!> \author MattW 21/11/2007
! **************************************************************************************************
   SUBROUTINE get_current_loads(loads, rs_descs, grid_level, ntasks, tasks, use_reordered_ranks)
      INTEGER(KIND=int_8), DIMENSION(:)                  :: loads
      TYPE(realspace_grid_desc_p_type), DIMENSION(:), &
         POINTER                                         :: rs_descs
      INTEGER                                            :: grid_level, ntasks
      TYPE(task_type), DIMENSION(:), POINTER             :: tasks
      LOGICAL, INTENT(IN)                                :: use_reordered_ranks

      CHARACTER(LEN=*), PARAMETER :: routineN = 'get_current_loads'

      INTEGER                                            :: dest, handle, i
      INTEGER(KIND=int_8)                                :: total_cost_local
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: recv_buf_i, send_buf_i
      TYPE(realspace_grid_desc_type), POINTER            :: desc

      CALL timeset(routineN, handle)

      desc => rs_descs(grid_level)%rs_desc

      ! allocate local arrays
      ALLOCATE (send_buf_i(desc%group_size))
      ALLOCATE (recv_buf_i(desc%group_size))

      ! communication step 1 : compute the total local cost of the tasks
      !                        each proc needs to know the amount of work he will receive

      ! send buffer now contains for each target the cost of the tasks it will receive
      send_buf_i = 0
      DO i = 1, ntasks
         IF (tasks(i)%grid_level /= grid_level) CYCLE
         dest = decode_rank(tasks(i)%destination, SIZE(rs_descs))
         ! ranks of this grid level may be permuted; account the load on the real process
         IF (use_reordered_ranks) dest = desc%virtual2real(dest)
         send_buf_i(dest + 1) = send_buf_i(dest + 1) + tasks(i)%cost
      END DO
      CALL desc%group%alltoall(send_buf_i, recv_buf_i, 1)

      ! communication step 2 : compute the global cost of the tasks
      total_cost_local = SUM(recv_buf_i)

      ! after this step, loads contains the local cost for each CPU
      CALL desc%group%allgather(total_cost_local, loads)

      DEALLOCATE (send_buf_i)
      DEALLOCATE (recv_buf_i)

      CALL timestop(handle)

   END SUBROUTINE get_current_loads
! **************************************************************************************************
!> \brief performs load balancing shifting tasks on the replicated grids
!>        this modifies the destination of some of the tasks on replicated
!>        grids, and in this way balances the load
!>
!>        Only tasks with dist_type == 0 whose destination rank still equals their source rank
!>        are candidates for moving. Each overloaded process offloads at most its excess load
!>        (clipped to its movable load) to the underloaded processes, the most overloaded
!>        process filling the most underloaded ones first.
!> \param rs_descs the grid descriptors; the group of rs_descs(1) is used for communication
!> \param ntasks number of entries of tasks to consider
!> \param tasks the local task list; the destination of some candidate tasks is changed
!> \par History
!>      none
!> \author MattW 21/11/2007
! **************************************************************************************************
   SUBROUTINE load_balance_replicated(rs_descs, ntasks, tasks)

      TYPE(realspace_grid_desc_p_type), DIMENSION(:), &
         POINTER                                         :: rs_descs
      INTEGER                                            :: ntasks
      TYPE(task_type), DIMENSION(:), POINTER             :: tasks

      CHARACTER(LEN=*), PARAMETER :: routineN = 'load_balance_replicated'

      INTEGER                                            :: handle, i, ilevel, j, nlevels, &
                                                            no_overloaded, no_underloaded, &
                                                            proc_receiving
      INTEGER(KIND=int_8)                                :: average_cost, cost_task_rep, offloaded, &
                                                            offset, total_cost_global
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: load_imbalance, loads, recv_buf_i
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: order
      TYPE(realspace_grid_desc_type), POINTER            :: desc

      CALL timeset(routineN, handle)

      desc => rs_descs(1)%rs_desc
      nlevels = SIZE(rs_descs)

      ! allocate local arrays
      ALLOCATE (recv_buf_i(desc%group_size))
      ALLOCATE (loads(desc%group_size))

      ! total load of each process, summed over all grid levels
      recv_buf_i = 0
      DO i = 1, nlevels
         CALL get_current_loads(loads, rs_descs, i, ntasks, tasks, use_reordered_ranks=.TRUE.)
         recv_buf_i(:) = recv_buf_i + loads
      END DO

      total_cost_global = SUM(recv_buf_i)
      average_cost = total_cost_global/desc%group_size

      !
      ! compute how to redistribute the replicated tasks so that the average cost is reached
      !

      ! load imbalance measures the load of a given CPU relative
      ! to the optimal load distribution (load=average)
      ALLOCATE (load_imbalance(desc%group_size))
      ALLOCATE (order(desc%group_size))

      load_imbalance(:) = recv_buf_i - average_cost
      no_overloaded = COUNT(load_imbalance > 0)
      no_underloaded = COUNT(load_imbalance < 0)

      ! sort the total loads; order maps sorted position -> (processor rank + 1),
      ! from the least to the most loaded processor
      CALL sort(recv_buf_i, SIZE(recv_buf_i), order)

      ! find out the cost of the replicated tasks this proc has,
      ! but only those tasks which have not yet been assigned
      cost_task_rep = 0
      DO i = 1, ntasks
         IF (tasks(i)%dist_type == 0 &
             .AND. decode_rank(tasks(i)%destination, nlevels) == decode_rank(tasks(i)%source, nlevels)) THEN
            cost_task_rep = cost_task_rep + tasks(i)%cost
         END IF
      END DO

      ! now, correct the load imbalance for the overloaded CPUs
      ! they will send away not more than the total load of replicated tasks
      CALL desc%group%allgather(cost_task_rep, recv_buf_i)

      DO i = 1, desc%group_size
         ! At the moment we can only offload replicated tasks
         IF (load_imbalance(i) > 0) &
            load_imbalance(i) = MIN(load_imbalance(i), recv_buf_i(i))
      END DO

      ! simplest algorithm I can think of is that the processor with the most
      ! excess tasks fills up the process needing most, then moves on to next most.
      ! At the moment if we've got less replicated tasks than we're overloaded then
      ! task balancing will be incomplete

      ! only need to do anything if I've excess tasks
      IF (load_imbalance(desc%my_pos + 1) > 0) THEN

         offloaded = 0 ! weighted amount of tasks offloaded
         offset = 0 ! load already sent by overloaded procs that are more overloaded than me

         ! calculate offset
         DO i = desc%group_size, desc%group_size - no_overloaded + 1, -1
            IF (order(i) == desc%my_pos + 1) EXIT
            offset = offset + load_imbalance(order(i))
         END DO

         ! find my starting processor to send to
         proc_receiving = HUGE(proc_receiving)
         DO i = 1, no_underloaded
            offset = offset + load_imbalance(order(i))
            IF (offset <= 0) THEN
               proc_receiving = i
               EXIT
            END IF
         END DO

         ! offset now contains minus the amount of work proc_receiving still requires
         ! we fill this up by adjusting the destination of tasks on the replicated grid,
         ! then move to next most underloaded proc
         DO j = 1, ntasks
            IF (tasks(j)%dist_type == 0 &
                .AND. decode_rank(tasks(j)%destination, nlevels) == decode_rank(tasks(j)%source, nlevels)) THEN
               ! just avoid sending to non existing procs due to integer truncation
               ! in the computation of the average
               IF (proc_receiving > no_underloaded) EXIT
               ! set new destination
               ilevel = tasks(j)%grid_level
               tasks(j)%destination = encode_rank(order(proc_receiving) - 1, ilevel, nlevels)
               offset = offset + tasks(j)%cost
               offloaded = offloaded + tasks(j)%cost
               IF (offloaded >= load_imbalance(desc%my_pos + 1)) EXIT
               IF (offset > 0) THEN
                  proc_receiving = proc_receiving + 1
                  ! just avoid sending to non existing procs due to integer truncation
                  ! in the computation of the average
                  IF (proc_receiving > no_underloaded) EXIT
                  offset = load_imbalance(order(proc_receiving))
               END IF
            END IF
         END DO
      END IF

      DEALLOCATE (order)
      DEALLOCATE (load_imbalance)
      DEALLOCATE (loads)
      DEALLOCATE (recv_buf_i)

      CALL timestop(handle)

   END SUBROUTINE load_balance_replicated

! **************************************************************************************************
!> \brief given an input task list, redistribute so that all tasks can be processed locally,
!>        i.e. dest equals rank
!>
!>        Works in three phases: the number of tasks per real destination process is exchanged,
!>        the tasks are serialized into int8 buffers and exchanged with a variable-size
!>        all-to-all, and finally the received buffer is deserialized into tasks_recv.
!> \param rs_descs the grid descriptors; the group of rs_descs(1) is used for communication
!> \param ntasks number of entries of tasks to send
!> \param tasks the local task list; each task goes to the real rank of its (virtual) destination
!> \param ntasks_recv on exit, number of tasks received
!> \param tasks_recv on exit, newly allocated array of the received tasks, ordered by source rank
!> \par History
!>      none
!> \author MattW 21/11/2007
! **************************************************************************************************
   SUBROUTINE create_local_tasks(rs_descs, ntasks, tasks, ntasks_recv, tasks_recv)

      TYPE(realspace_grid_desc_p_type), DIMENSION(:), &
         POINTER                                         :: rs_descs
      INTEGER                                            :: ntasks
      TYPE(task_type), DIMENSION(:), POINTER             :: tasks
      INTEGER                                            :: ntasks_recv
      TYPE(task_type), DIMENSION(:), POINTER             :: tasks_recv

      CHARACTER(LEN=*), PARAMETER :: routineN = 'create_local_tasks'

      INTEGER                                            :: handle, ilevel, iproc, itask, nlevels, &
                                                            nprocs, pos
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: recv_buf, send_buf
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: recv_disps, recv_sizes, send_disps, &
                                                            send_sizes
      TYPE(realspace_grid_desc_type), POINTER            :: desc

      CALL timeset(routineN, handle)

      desc => rs_descs(1)%rs_desc
      nprocs = desc%group_size
      nlevels = SIZE(rs_descs)

      ALLOCATE (send_sizes(nprocs))
      ALLOCATE (recv_sizes(nprocs))
      ALLOCATE (send_disps(nprocs))
      ALLOCATE (recv_disps(nprocs))

      ! phase 1: tell every process how many tasks it will receive
      ! (the counts live in int8 buffers for convenience only)
      ALLOCATE (send_buf(nprocs))
      ALLOCATE (recv_buf(nprocs))
      send_buf(:) = 0
      DO itask = 1, ntasks
         ilevel = decode_level(tasks(itask)%destination, nlevels)
         iproc = rs_descs(ilevel)%rs_desc%virtual2real(decode_rank(tasks(itask)%destination, nlevels)) + 1
         send_buf(iproc) = send_buf(iproc) + 1
      END DO
      CALL desc%group%alltoall(send_buf, recv_buf, 1)

      ! message sizes in int8 words, and their running offsets
      send_sizes(:) = INT(send_buf(:)*task_size_in_int8)
      recv_sizes(:) = INT(recv_buf(:)*task_size_in_int8)
      send_disps(1) = 0
      recv_disps(1) = 0
      DO iproc = 2, nprocs
         send_disps(iproc) = send_disps(iproc - 1) + send_sizes(iproc - 1)
         recv_disps(iproc) = recv_disps(iproc - 1) + recv_sizes(iproc - 1)
      END DO

      DEALLOCATE (send_buf)
      DEALLOCATE (recv_buf)

      ! phase 2: serialize each task into the segment of its real destination process
      ALLOCATE (send_buf(SUM(send_sizes)))
      ALLOCATE (recv_buf(SUM(recv_sizes)))
      send_buf(:) = 0
      send_sizes(:) = 0 ! refilled below, doubling as the fill pointer of each segment
      DO itask = 1, ntasks
         ilevel = decode_level(tasks(itask)%destination, nlevels)
         iproc = rs_descs(ilevel)%rs_desc%virtual2real(decode_rank(tasks(itask)%destination, nlevels)) + 1
         pos = send_disps(iproc) + send_sizes(iproc)
         CALL serialize_task(tasks(itask), send_buf(pos + 1:pos + task_size_in_int8))
         send_sizes(iproc) = send_sizes(iproc) + task_size_in_int8
      END DO

      CALL desc%group%alltoall(send_buf, send_sizes, send_disps, recv_buf, recv_sizes, recv_disps)

      DEALLOCATE (send_buf)

      ! phase 3: the received segments are contiguous and ordered by source rank, each a
      ! whole number of tasks, so task number itask starts at word (itask - 1)*task_size_in_int8
      ntasks_recv = SUM(recv_sizes)/task_size_in_int8
      ALLOCATE (tasks_recv(ntasks_recv))
      DO itask = 1, ntasks_recv
         pos = (itask - 1)*task_size_in_int8
         CALL deserialize_task(tasks_recv(itask), recv_buf(pos + 1:pos + task_size_in_int8))
      END DO

      DEALLOCATE (recv_buf)
      DEALLOCATE (send_sizes)
      DEALLOCATE (recv_sizes)
      DEALLOCATE (send_disps)
      DEALLOCATE (recv_disps)

      CALL timestop(handle)

   END SUBROUTINE create_local_tasks

! **************************************************************************************************
!> \brief Assembles tasks to be performed on local grid
!>
!>        If any grid level is distributed, the distributed levels are load balanced (optionally),
!>        their ranks are optionally reordered, the replicated tasks are used to even out the
!>        remaining load, and the tasks are sent to the processes that will execute them.
!>        Otherwise the task list is kept as is. In both cases the resulting local list is sorted.
!> \param rs_descs the grids
!> \param ntasks on entry, number of tasks generated on this processor;
!>               on exit, number of tasks for local processing
!> \param natoms number of atoms, passed on to load_balance_distributed
!> \param tasks on entry, the task set generated on this processor;
!>              on exit, the sorted task set to be processed locally
!> \param atom_pair_send on exit, the unique atom pairs to be sent
!>                       (only built when some grid level is distributed)
!> \param atom_pair_recv on exit, the unique atom pairs to be received
!> \param symmetric if true, atom pairs are stored with row <= col
!> \param reorder_rs_grid_ranks if true, the ranks of distributed grid levels (other than the
!>                              first one) may be reordered to even out the load
!> \param skip_load_balance_distributed if true, load_balance_distributed is not called
!> \par History
!>      none
!> \author MattW 21/11/2007
! **************************************************************************************************
   SUBROUTINE distribute_tasks(rs_descs, ntasks, natoms, &
                               tasks, atom_pair_send, atom_pair_recv, &
                               symmetric, reorder_rs_grid_ranks, skip_load_balance_distributed)

      TYPE(realspace_grid_desc_p_type), DIMENSION(:), &
         POINTER                                         :: rs_descs
      INTEGER                                            :: ntasks, natoms
      TYPE(task_type), DIMENSION(:), POINTER             :: tasks
      TYPE(atom_pair_type), DIMENSION(:), POINTER        :: atom_pair_send, atom_pair_recv
      LOGICAL, INTENT(IN)                                :: symmetric, reorder_rs_grid_ranks, &
                                                            skip_load_balance_distributed

      CHARACTER(LEN=*), PARAMETER :: routineN = 'distribute_tasks'

      INTEGER                                            :: handle, igrid_level, irank, ntasks_recv
      INTEGER(KIND=int_8)                                :: load_gap, max_load, replicated_load
      INTEGER(KIND=int_8), ALLOCATABLE, DIMENSION(:)     :: total_loads, total_loads_tmp, trial_loads
      INTEGER(KIND=int_8), DIMENSION(:, :), POINTER      :: loads
      INTEGER, ALLOCATABLE, DIMENSION(:)                 :: indices, real2virtual, total_index
      LOGICAL                                            :: distributed_grids, fixed_first_grid
      TYPE(realspace_grid_desc_type), POINTER            :: desc
      TYPE(task_type), DIMENSION(:), POINTER             :: tasks_recv

      CALL timeset(routineN, handle)

      CPASSERT(ASSOCIATED(tasks))

      ! *** figure out if we have distributed grids
      distributed_grids = .FALSE.
      DO igrid_level = 1, SIZE(rs_descs)
         IF (rs_descs(igrid_level)%rs_desc%distributed) THEN
            distributed_grids = .TRUE.
         END IF
      END DO
      desc => rs_descs(1)%rs_desc

      IF (distributed_grids) THEN

         ALLOCATE (loads(0:desc%group_size - 1, SIZE(rs_descs)))
         ALLOCATE (total_loads(0:desc%group_size - 1))

         total_loads = 0

         ! First round of balancing on the distributed grids
         ! we just balance each of the distributed grids independently
         DO igrid_level = 1, SIZE(rs_descs)
            IF (rs_descs(igrid_level)%rs_desc%distributed) THEN

               IF (.NOT. skip_load_balance_distributed) &
                  CALL load_balance_distributed(tasks, ntasks, rs_descs, igrid_level, natoms)

               CALL get_current_loads(loads(:, igrid_level), rs_descs, igrid_level, ntasks, &
                                      tasks, use_reordered_ranks=.FALSE.)

               total_loads(:) = total_loads + loads(:, igrid_level)

            END IF
         END DO

         ! calculate the total load of replicated tasks, so we can decide if it is worth
         ! reordering the distributed grid levels

         replicated_load = 0
         DO igrid_level = 1, SIZE(rs_descs)
            IF (.NOT. rs_descs(igrid_level)%rs_desc%distributed) THEN
               CALL get_current_loads(loads(:, igrid_level), rs_descs, igrid_level, ntasks, &
                                      tasks, use_reordered_ranks=.FALSE.)
               replicated_load = replicated_load + SUM(loads(:, igrid_level))
            END IF
         END DO

         !IF (desc%my_pos==0) THEN
         ! WRITE(*,*) "Total replicated load is ",replicated_load
         !END IF

         ! Now we adjust the rank ordering based on the current loads
         ! we leave the first distributed level and all the replicated levels in the default order
         IF (reorder_rs_grid_ranks) THEN
            fixed_first_grid = .FALSE.
            DO igrid_level = 1, SIZE(rs_descs)
               IF (rs_descs(igrid_level)%rs_desc%distributed) THEN
                  IF (.NOT. fixed_first_grid) THEN
                     total_loads(:) = loads(:, igrid_level)
                     fixed_first_grid = .TRUE.
                  ELSE
                     ALLOCATE (trial_loads(0:desc%group_size - 1))

                     trial_loads(:) = total_loads + loads(:, igrid_level)
                     max_load = MAXVAL(trial_loads)
                     load_gap = 0
                     DO irank = 0, desc%group_size - 1
                        load_gap = load_gap + max_load - trial_loads(irank)
                     END DO

                     ! If there is not enough replicated load to load balance well enough
                     ! then we will reorder this grid level
                     IF (load_gap > replicated_load*1.05_dp) THEN

                        ALLOCATE (indices(0:desc%group_size - 1))
                        ALLOCATE (total_index(0:desc%group_size - 1))
                        ALLOCATE (total_loads_tmp(0:desc%group_size - 1))
                        ALLOCATE (real2virtual(0:desc%group_size - 1))

                        total_loads_tmp(:) = total_loads
                        CALL sort(total_loads_tmp, desc%group_size, total_index)
                        ! note: this sorts the loads of this grid level in place
                        CALL sort(loads(:, igrid_level), desc%group_size, indices)

                        ! Reorder so that the rank with smallest load in total is paired with
                        ! the highest load on this grid level
                        DO irank = 0, desc%group_size - 1
                           total_loads(total_index(irank) - 1) = total_loads(total_index(irank) - 1) + &
                                                                 loads(desc%group_size - irank - 1, igrid_level)
                           real2virtual(total_index(irank) - 1) = indices(desc%group_size - irank - 1) - 1
                        END DO

                        CALL rs_grid_reorder_ranks(rs_descs(igrid_level)%rs_desc, real2virtual)

                        DEALLOCATE (indices)
                        DEALLOCATE (total_index)
                        DEALLOCATE (total_loads_tmp)
                        DEALLOCATE (real2virtual)
                     ELSE
                        total_loads(:) = trial_loads
                     END IF

                     DEALLOCATE (trial_loads)

                  END IF
               END IF
            END DO
         END IF

         ! Now we use the replicated tasks to balance out the rest of the load
         CALL load_balance_replicated(rs_descs, ntasks, tasks)

         !total_loads = 0
         !DO igrid_level=1,SIZE(rs_descs)
         !  CALL get_current_loads(loads(:,igrid_level), rs_descs, igrid_level, ntasks, &
         !                         tasks, use_reordered_ranks=.TRUE.)
         !  total_loads = total_loads + loads(:, igrid_level)
         !END DO

         !IF (desc%my_pos==0) THEN
         !  WRITE(*,*) ""
         !  WRITE(*,*) "At the end of the load balancing procedure"
         !  WRITE(*,*) "Maximum  load:",MAXVAL(total_loads)
         !  WRITE(*,*) "Average  load:",SUM(total_loads)/SIZE(total_loads)
         !  WRITE(*,*) "Minimum  load:",MINVAL(total_loads)
         !ENDIF

         ! given a list of tasks, this will do the needed reshuffle so that all tasks will be local
         CALL create_local_tasks(rs_descs, ntasks, tasks, ntasks_recv, tasks_recv)

         !
         ! tasks list are complete, we can compute the list of atomic blocks (atom pairs)
         ! we will be sending. These lists are needed for redistribute_matrix.
         !
         CALL get_atom_pair(atom_pair_send, tasks, ntasks=ntasks, send=.TRUE., symmetric=symmetric, rs_descs=rs_descs)

         ! natom_send=SIZE(atom_pair_send)
         ! CALL desc%group%sum(natom_send)
         ! IF (desc%my_pos==0) THEN
         !     WRITE(*,*) ""
         !     WRITE(*,*) "Total number of atomic blocks to be send:",natom_send
         ! ENDIF

         CALL get_atom_pair(atom_pair_recv, tasks_recv, ntasks=ntasks_recv, send=.FALSE., symmetric=symmetric, rs_descs=rs_descs)

         ! cleanup, at this point we don't need the original tasks anymore
         DEALLOCATE (tasks)
         DEALLOCATE (loads)
         DEALLOCATE (total_loads)

      ELSE
         tasks_recv => tasks
         ntasks_recv = ntasks
         CALL get_atom_pair(atom_pair_recv, tasks_recv, ntasks=ntasks_recv, send=.FALSE., symmetric=symmetric, rs_descs=rs_descs)
         ! not distributed, hence atom_pair_send not needed
      END IF

      ! here we sort the task list we will process locally.
      ALLOCATE (indices(ntasks_recv))
      CALL tasks_sort(tasks_recv, ntasks_recv, indices)
      DEALLOCATE (indices)

      !
      ! final lists are ready
      !

      tasks => tasks_recv
      ntasks = ntasks_recv

      CALL timestop(handle)

   END SUBROUTINE distribute_tasks

! **************************************************************************************************
!> \brief Builds the sorted list of unique atom pairs (matrix blocks) referenced by a task list
!>        and stores in each task the index of its pair in that list.
!> \param atom_pair must not be associated on entry; on exit, the sorted unique atom pairs
!>                  (an empty array if ntasks == 0)
!> \param tasks the tasks; the pair_index component of the first ntasks entries is set
!> \param ntasks number of entries of tasks to process
!> \param send if true, the rank of a pair is the real rank of the task destination
!>             (virtual rank mapped through virtual2real of the task's grid level);
!>             otherwise it is the rank of the task source
!> \param symmetric if true, pairs are stored with row <= col (iatom and jatom swapped if needed)
!> \param rs_descs the grids, used to decode ranks
! **************************************************************************************************
   SUBROUTINE get_atom_pair(atom_pair, tasks, ntasks, send, symmetric, rs_descs)

      TYPE(atom_pair_type), DIMENSION(:), POINTER        :: atom_pair
      TYPE(task_type), DIMENSION(:), INTENT(INOUT)       :: tasks
      INTEGER, INTENT(IN)                                :: ntasks
      LOGICAL, INTENT(IN)                                :: send, symmetric
      TYPE(realspace_grid_desc_p_type), DIMENSION(:), INTENT(IN) :: rs_descs

      INTEGER                                            :: i, ilevel, iatom, jatom, npairs, virt_rank
      INTEGER, DIMENSION(:), ALLOCATABLE                 :: indices
      TYPE(atom_pair_type), DIMENSION(:), ALLOCATABLE    :: atom_pair_tmp

      CPASSERT(.NOT. ASSOCIATED(atom_pair))
      IF (ntasks == 0) THEN
         ALLOCATE (atom_pair(0))
         RETURN
      END IF

      ! calculate list of atom pairs
      ! fill pair list taking into account symmetry
      ALLOCATE (atom_pair_tmp(ntasks))
      DO i = 1, ntasks
         atom_pair_tmp(i)%image = tasks(i)%image
         iatom = tasks(i)%iatom
         jatom = tasks(i)%jatom
         IF (symmetric .AND. iatom > jatom) THEN
            ! iatom / jatom swapped
            atom_pair_tmp(i)%row = jatom
            atom_pair_tmp(i)%col = iatom
         ELSE
            atom_pair_tmp(i)%row = iatom
            atom_pair_tmp(i)%col = jatom
         END IF

         IF (send) THEN
            ! If sending, we need to use the 'real rank' as the pair has to be sent to the process which
            ! actually has the correct part of the rs_grid to do the mapping
            ilevel = tasks(i)%grid_level
            virt_rank = decode_rank(tasks(i)%destination, SIZE(rs_descs))
            atom_pair_tmp(i)%rank = rs_descs(ilevel)%rs_desc%virtual2real(virt_rank)
         ELSE
            ! If we are receiving, then no conversion is needed as the rank is that of the process with the
            ! required matrix block, and the ordering of the rs grid is irrelevant
            atom_pair_tmp(i)%rank = decode_rank(tasks(i)%source, SIZE(rs_descs))
         END IF
      END DO

      ! find unique atom pairs that I'm sending/receiving:
      ! after sorting, a pair starts a new group exactly when it is larger than the last unique one;
      ! the unique pairs are compacted in place at the front of atom_pair_tmp
      ALLOCATE (indices(ntasks))
      CALL atom_pair_sort(atom_pair_tmp, ntasks, indices)
      npairs = 1
      tasks(indices(1))%pair_index = 1
      DO i = 2, ntasks
         IF (atom_pair_less_than(atom_pair_tmp(npairs), atom_pair_tmp(i))) THEN
            npairs = npairs + 1
            atom_pair_tmp(npairs) = atom_pair_tmp(i)
         END IF
         tasks(indices(i))%pair_index = npairs
      END DO
      DEALLOCATE (indices)

      ! Copy unique pairs to final location.
      ALLOCATE (atom_pair(npairs))
      atom_pair(:) = atom_pair_tmp(:npairs)
      DEALLOCATE (atom_pair_tmp)

   END SUBROUTINE get_atom_pair

! **************************************************************************************************
!> \brief redistributes the matrix so that it can be used in realspace operations
!>        i.e. according to the task lists for collocate and integrate.
!>        This routine can become a bottleneck in large calculations.
!>
!>        With scatter=.TRUE. the blocks listed in atom_pair_send are packed from pmats and
!>        exchanged with an all-to-all; received blocks that are not yet present are inserted
!>        into pmats. With scatter=.FALSE. the received blocks are summed into hmats instead,
!>        and blocks destined for this rank are added directly (bypassing the all-to-all).
!> \param rs_descs realspace grid descriptors; only rs_descs(1)%rs_desc is used
!>                 (process group, own position and group size)
!> \param pmats matrices (one per image) providing the blocks to send; in scatter mode the
!>              received blocks are also inserted here
!> \param atom_pair_send blocks to send; each destination rank is addressed as a contiguous
!>                       range of this list, so it is assumed to be grouped by rank
!> \param atom_pair_recv blocks to receive, assumed grouped by source rank in the same way
!> \param nimages number of images, i.e. matrices in pmats (and hmats)
!> \param scatter .TRUE. to scatter pmats, .FALSE. to gather and sum into hmats
!> \param hmats matrices (one per image) into which blocks are summed; required when
!>              scatter is .FALSE.
! **************************************************************************************************
   SUBROUTINE rs_distribute_matrix(rs_descs, pmats, atom_pair_send, atom_pair_recv, &
                                   nimages, scatter, hmats)

      TYPE(realspace_grid_desc_p_type), DIMENSION(:), &
         POINTER                                         :: rs_descs
      TYPE(dbcsr_p_type), DIMENSION(:), POINTER          :: pmats
      TYPE(atom_pair_type), DIMENSION(:), POINTER        :: atom_pair_send, atom_pair_recv
      INTEGER                                            :: nimages
      LOGICAL                                            :: scatter
      TYPE(dbcsr_p_type), DIMENSION(:), OPTIONAL, &
         POINTER                                         :: hmats

      CHARACTER(LEN=*), PARAMETER :: routineN = 'rs_distribute_matrix'

      INTEGER                                            :: acol, arow, handle, i, img, j, k, l, me, &
                                                            nblkcols_total, nblkrows_total, ncol, &
                                                            nrow, nthread, nthread_left
      INTEGER, ALLOCATABLE, DIMENSION(:) :: first_col, first_row, last_col, last_row, recv_disps, &
                                            recv_pair_count, recv_pair_disps, recv_sizes, send_disps, send_pair_count, &
                                            send_pair_disps, send_sizes
      INTEGER, DIMENSION(:), POINTER                     :: col_blk_size, row_blk_size
      LOGICAL                                            :: found
      REAL(KIND=dp), ALLOCATABLE, DIMENSION(:), TARGET   :: recv_buf_r, send_buf_r
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: h_block, p_block
      TYPE(dbcsr_type), POINTER                          :: hmat, pmat
      TYPE(realspace_grid_desc_type), POINTER            :: desc
      REAL(kind=dp), DIMENSION(:), POINTER                            :: vector

      ! locks protecting the concurrent summation into H (gather mode only)
!$    INTEGER(kind=omp_lock_kind), ALLOCATABLE, DIMENSION(:) :: locks

      CALL timeset(routineN, handle)

      ! the gather path dereferences hmats, so it must be supplied there
      IF (.NOT. scatter) THEN
         CPASSERT(PRESENT(hmats))
      END IF

      ! process group and own position (1-based) are taken from the first descriptor
      desc => rs_descs(1)%rs_desc
      me = desc%my_pos + 1

      ! allocate local arrays (one entry per rank of the group)
      ALLOCATE (send_sizes(desc%group_size))
      ALLOCATE (recv_sizes(desc%group_size))
      ALLOCATE (send_disps(desc%group_size))
      ALLOCATE (recv_disps(desc%group_size))
      ALLOCATE (send_pair_count(desc%group_size))
      ALLOCATE (recv_pair_count(desc%group_size))
      ALLOCATE (send_pair_disps(desc%group_size))
      ALLOCATE (recv_pair_disps(desc%group_size))

      ! block dimensions are derived from the first image's block sizes
      ! (NOTE(review): assumes all images share the same blocking - confirm)
      pmat => pmats(1)%matrix
      CALL dbcsr_get_info(pmat, &
                          row_blk_size=row_blk_size, &
                          col_blk_size=col_blk_size, &
                          nblkrows_total=nblkrows_total, &
                          nblkcols_total=nblkcols_total)
      ALLOCATE (first_row(nblkrows_total), last_row(nblkrows_total), &
                first_col(nblkcols_total), last_col(nblkcols_total))
      CALL dbcsr_convert_sizes_to_offsets(row_blk_size, first_row, last_row)
      CALL dbcsr_convert_sizes_to_offsets(col_blk_size, first_col, last_col)

      ! set up send buffer sizes
      ! (per destination rank: total number of elements and number of blocks)
      send_sizes = 0
      send_pair_count = 0
      DO i = 1, SIZE(atom_pair_send)
         k = atom_pair_send(i)%rank + 1 ! proc we're sending this block to
         arow = atom_pair_send(i)%row
         acol = atom_pair_send(i)%col
         nrow = last_row(arow) - first_row(arow) + 1
         ncol = last_col(acol) - first_col(acol) + 1
         send_sizes(k) = send_sizes(k) + nrow*ncol
         send_pair_count(k) = send_pair_count(k) + 1
      END DO

      ! exclusive prefix sums: element and pair offsets of each rank's chunk
      send_disps = 0
      send_pair_disps = 0
      DO i = 2, desc%group_size
         send_disps(i) = send_disps(i - 1) + send_sizes(i - 1)
         send_pair_disps(i) = send_pair_disps(i - 1) + send_pair_count(i - 1)
      END DO

      ALLOCATE (send_buf_r(SUM(send_sizes)))

      ! set up recv buffer (same bookkeeping as for sending, per source rank)

      recv_sizes = 0
      recv_pair_count = 0
      DO i = 1, SIZE(atom_pair_recv)
         k = atom_pair_recv(i)%rank + 1 ! proc we're receiving this data from
         arow = atom_pair_recv(i)%row
         acol = atom_pair_recv(i)%col
         nrow = last_row(arow) - first_row(arow) + 1
         ncol = last_col(acol) - first_col(acol) + 1
         recv_sizes(k) = recv_sizes(k) + nrow*ncol
         recv_pair_count(k) = recv_pair_count(k) + 1
      END DO

      recv_disps = 0
      recv_pair_disps = 0
      DO i = 2, desc%group_size
         recv_disps(i) = recv_disps(i - 1) + recv_sizes(i - 1)
         recv_pair_disps(i) = recv_pair_disps(i - 1) + recv_pair_count(i - 1)
      END DO
      ALLOCATE (recv_buf_r(SUM(recv_sizes)))

!$OMP PARALLEL DEFAULT(OMP_DEFAULT_NONE_WITH_OOP) &
!$OMP          SHARED(desc,send_pair_count,send_pair_disps,nimages),&
!$OMP          SHARED(last_row,first_row,last_col,first_col),&
!$OMP          SHARED(pmats,send_buf_r,send_disps,send_sizes),&
!$OMP          SHARED(atom_pair_send,me,hmats,nblkrows_total),&
!$OMP          SHARED(atom_pair_recv,recv_buf_r,scatter,recv_pair_disps), &
!$OMP          SHARED(recv_sizes,recv_disps,recv_pair_count,locks), &
!$OMP          PRIVATE(i,img,arow,acol,nrow,ncol,p_block,found,j,k,l),&
!$OMP          PRIVATE(nthread,h_block,nthread_left,hmat,pmat,vector)

      ! nthread_left excludes the master thread, which performs the alltoall below
      nthread = 1
!$    nthread = omp_get_num_threads()
      nthread_left = 1
!$    nthread_left = MAX(1, nthread - 1)

      ! do packing
      ! send_sizes(l) is reused as a running offset and ends at the value computed above.
      ! NOTE(review): l == me is skipped, so the send_sizes(me) elements of send_buf_r are never
      ! written, yet they are still passed to the alltoall; the matching received region is
      ! ignored during unpacking.
!$OMP DO schedule(guided)
      DO l = 1, desc%group_size
         IF (l == me) CYCLE
         send_sizes(l) = 0
         DO i = 1, send_pair_count(l)
            arow = atom_pair_send(send_pair_disps(l) + i)%row
            acol = atom_pair_send(send_pair_disps(l) + i)%col
            img = atom_pair_send(send_pair_disps(l) + i)%image
            nrow = last_row(arow) - first_row(arow) + 1
            ncol = last_col(acol) - first_col(acol) + 1
            pmat => pmats(img)%matrix
            CALL dbcsr_get_block_p(matrix=pmat, row=arow, col=acol, block=p_block, found=found)
            CPASSERT(found)

            ! copy the block column-major into the send buffer
            DO k = 1, ncol
               DO j = 1, nrow
                  send_buf_r(send_disps(l) + send_sizes(l) + j + (k - 1)*nrow) = p_block(j, k)
               END DO
            END DO
            send_sizes(l) = send_sizes(l) + nrow*ncol
         END DO
      END DO
!$OMP END DO

      IF (.NOT. scatter) THEN
         ! We need locks to protect concurrent summation into H
         ! (nthread*10 locks; block rows are mapped onto them in contiguous bands)
!$OMP SINGLE
!$       ALLOCATE (locks(nthread*10))
!$OMP END SINGLE

!$OMP DO
!$       do i = 1, nthread*10
!$          call omp_init_lock(locks(i))
!$       end do
!$OMP END DO
      END IF

!$OMP MASTER
      ! do communication
      CALL desc%group%alltoall(send_buf_r, send_sizes, send_disps, &
                               recv_buf_r, recv_sizes, recv_disps)
!$OMP END MASTER

      ! If this is a scatter, then no need to copy local blocks,
      ! If not, we sum them directly into H (bypassing the alltoall)
      IF (.NOT. scatter) THEN

         ! Distribute work over remaining threads assuming one is still in the alltoall
!$OMP DO schedule(dynamic,MAX(1,send_pair_count(me)/nthread_left))
         DO i = 1, send_pair_count(me)
            arow = atom_pair_send(send_pair_disps(me) + i)%row
            acol = atom_pair_send(send_pair_disps(me) + i)%col
            img = atom_pair_send(send_pair_disps(me) + i)%image
            nrow = last_row(arow) - first_row(arow) + 1
            ncol = last_col(acol) - first_col(acol) + 1
            hmat => hmats(img)%matrix
            pmat => pmats(img)%matrix
            CALL dbcsr_get_block_p(matrix=hmat, row=arow, col=acol, BLOCK=h_block, found=found)
            CPASSERT(found)
            CALL dbcsr_get_block_p(matrix=pmat, row=arow, col=acol, BLOCK=p_block, found=found)
            CPASSERT(found)

            ! lock the band of block rows that arow belongs to
!$          call omp_set_lock(locks((arow - 1)*nthread*10/nblkrows_total + 1))
            DO k = 1, ncol
               DO j = 1, nrow
                  h_block(j, k) = h_block(j, k) + p_block(j, k)
               END DO
            END DO
!$          call omp_unset_lock(locks((arow - 1)*nthread*10/nblkrows_total + 1))
         END DO
!$OMP END DO
      ELSE
         ! We will insert new blocks into P, so create mutable work matrices
         DO img = 1, nimages
            pmat => pmats(img)%matrix
            CALL dbcsr_work_create(pmat, work_mutable=.TRUE., &
                                   nblks_guess=SIZE(atom_pair_recv)/nthread, sizedata_guess=SIZE(recv_buf_r)/nthread, &
                                   n=nthread)
         END DO
      END IF

! wait for comm and setup to finish
!$OMP BARRIER

      !do unpacking
      ! scatter: blocks not yet present in P are inserted from the received data;
      ! gather: the received data is summed into the corresponding H block
!$OMP DO schedule(guided)
      DO l = 1, desc%group_size
         IF (l == me) CYCLE
         recv_sizes(l) = 0
         DO i = 1, recv_pair_count(l)
            arow = atom_pair_recv(recv_pair_disps(l) + i)%row
            acol = atom_pair_recv(recv_pair_disps(l) + i)%col
            img = atom_pair_recv(recv_pair_disps(l) + i)%image
            nrow = last_row(arow) - first_row(arow) + 1
            ncol = last_col(acol) - first_col(acol) + 1
            pmat => pmats(img)%matrix
            NULLIFY (p_block)
            CALL dbcsr_get_block_p(matrix=pmat, row=arow, col=acol, BLOCK=p_block, found=found)

            IF (PRESENT(hmats)) THEN
               hmat => hmats(img)%matrix
               CALL dbcsr_get_block_p(matrix=hmat, row=arow, col=acol, BLOCK=h_block, found=found)
               CPASSERT(found)
            END IF

            IF (scatter .AND. .NOT. ASSOCIATED(p_block)) THEN
               vector => recv_buf_r(recv_disps(l) + recv_sizes(l) + 1:recv_disps(l) + recv_sizes(l) + nrow*ncol)
               CALL dbcsr_put_block(pmat, arow, acol, block=RESHAPE(vector, [nrow, ncol]))
            END IF
            IF (.NOT. scatter) THEN
!$             call omp_set_lock(locks((arow - 1)*nthread*10/nblkrows_total + 1))
               DO k = 1, ncol
                  DO j = 1, nrow
                     h_block(j, k) = h_block(j, k) + recv_buf_r(recv_disps(l) + recv_sizes(l) + j + (k - 1)*nrow)
                  END DO
               END DO
!$             call omp_unset_lock(locks((arow - 1)*nthread*10/nblkrows_total + 1))
            END IF
            recv_sizes(l) = recv_sizes(l) + nrow*ncol
         END DO
      END DO
!$OMP END DO

!$    IF (.not. scatter) THEN
!$OMP DO
!$       do i = 1, nthread*10
!$          call omp_destroy_lock(locks(i))
!$       end do
!$OMP END DO
!$    END IF

!$OMP SINGLE
!$    IF (.not. scatter) THEN
!$       DEALLOCATE (locks)
!$    END IF
!$OMP END SINGLE NOWAIT

      IF (scatter) THEN
         ! Blocks were added to P
         DO img = 1, nimages
            pmat => pmats(img)%matrix
            CALL dbcsr_finalize(pmat)
         END DO
      END IF
!$OMP END PARALLEL

      DEALLOCATE (send_buf_r)
      DEALLOCATE (recv_buf_r)

      DEALLOCATE (send_sizes)
      DEALLOCATE (recv_sizes)
      DEALLOCATE (send_disps)
      DEALLOCATE (recv_disps)
      DEALLOCATE (send_pair_count)
      DEALLOCATE (recv_pair_count)
      DEALLOCATE (send_pair_disps)
      DEALLOCATE (recv_pair_disps)

      DEALLOCATE (first_row, last_row, first_col, last_col)

      CALL timestop(handle)

   END SUBROUTINE rs_distribute_matrix

! **************************************************************************************************
!> \brief Calculates offsets and sizes for rs_scatter_matrix and rs_copy_matrix.
!>        Blocks are laid out back to back in the order of pairs; the pairs of one rank
!>        therefore form one contiguous chunk of the buffer.
!> \param pairs atom pairs to lay out; must be ordered by rank (checked)
!> \param nsgf per-atom block dimension; the block of a pair has nsgf(row)*nsgf(col) elements
!> \param group_size number of ranks
!> \param pair_offsets (re)allocated; 0-based buffer offset of each pair's block
!> \param rank_offsets (re)allocated; 0-based offset of each rank's chunk (0 for ranks without pairs)
!> \param rank_sizes (re)allocated; number of buffer elements belonging to each rank
!> \param buffer_size total number of buffer elements
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE rs_calc_offsets(pairs, nsgf, group_size, &
                              pair_offsets, rank_offsets, rank_sizes, buffer_size)
      TYPE(atom_pair_type), DIMENSION(:), INTENT(IN)     :: pairs
      INTEGER, DIMENSION(:), INTENT(IN)                  :: nsgf
      INTEGER, INTENT(IN)                                :: group_size
      INTEGER, DIMENSION(:), POINTER                     :: pair_offsets, rank_offsets, rank_sizes
      INTEGER, INTENT(INOUT)                             :: buffer_size

      INTEGER                                            :: cur_rank, ipair, npairs, prev_rank, &
                                                            running_offset

      npairs = SIZE(pairs)

      ! drop the results of a previous call before rebuilding them
      IF (ASSOCIATED(pair_offsets)) DEALLOCATE (pair_offsets)
      IF (ASSOCIATED(rank_offsets)) DEALLOCATE (rank_offsets)
      IF (ASSOCIATED(rank_sizes)) DEALLOCATE (rank_sizes)

      ! every block starts where the previous one ended
      ALLOCATE (pair_offsets(npairs))
      running_offset = 0
      DO ipair = 1, npairs
         pair_offsets(ipair) = running_offset
         running_offset = running_offset + nsgf(pairs(ipair)%row)*nsgf(pairs(ipair)%col)
      END DO
      buffer_size = running_offset

      ! per-rank chunks; ranks owning no pairs keep offset 0 and size 0
      ALLOCATE (rank_offsets(group_size), rank_sizes(group_size))
      rank_offsets(:) = 0
      rank_sizes(:) = 0
      IF (npairs == 0) RETURN

      prev_rank = pairs(1)%rank + 1
      cur_rank = prev_rank
      DO ipair = 1, npairs
         cur_rank = pairs(ipair)%rank + 1
         CPASSERT(cur_rank >= prev_rank) ! pairs must be sorted by rank
         IF (cur_rank <= prev_rank) CYCLE
         ! a new rank starts at this pair: open its chunk and close the previous one
         rank_offsets(cur_rank) = pair_offsets(ipair)
         rank_sizes(prev_rank) = pair_offsets(ipair) - rank_offsets(prev_rank)
         prev_rank = cur_rank
      END DO
      ! the chunk of the last pair's rank extends to the end of the buffer
      rank_sizes(cur_rank) = buffer_size - rank_offsets(cur_rank)

   END SUBROUTINE rs_calc_offsets

! **************************************************************************************************
!> \brief Scatters dbcsr matrix blocks and receives them into a buffer as needed before collocation.
!>        The blocks listed in task_list%atom_pair_send are packed into a temporary send buffer
!>        and exchanged with an all-to-all whose receive side is dest_buffer%host_buffer.
!> \param src_matrices matrices (one per image) the blocks are read from
!> \param dest_buffer receives the blocks at task_list%rank_offsets_recv / rank_sizes_recv
!> \param task_list provides the pair lists, offsets and sizes of both directions
!> \param group communicator over which the all-to-all is performed
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE rs_scatter_matrices(src_matrices, dest_buffer, task_list, group)
      TYPE(dbcsr_p_type), DIMENSION(:), INTENT(IN)       :: src_matrices
      TYPE(offload_buffer_type), INTENT(INOUT)           :: dest_buffer
      TYPE(task_list_type), INTENT(IN)                   :: task_list
      TYPE(mp_comm_type), INTENT(IN)                     :: group

      CHARACTER(LEN=*), PARAMETER :: routineN = 'rs_scatter_matrices'

      INTEGER                                            :: timer_handle
      REAL(KIND=dp), DIMENSION(:), ALLOCATABLE           :: send_buf

      CALL timeset(routineN, timer_handle)

      ! staging area for the outgoing blocks; released before returning
      ALLOCATE (send_buf(task_list%buffer_size_send))

      CPASSERT(ASSOCIATED(task_list%atom_pair_send))
      CALL rs_pack_buffer(src_matrices, send_buf, &
                          task_list%atom_pair_send, task_list%pair_offsets_send)

      ! each rank's chunk is sent to that rank; incoming data lands directly in dest_buffer
      CALL group%alltoall(send_buf, task_list%rank_sizes_send, task_list%rank_offsets_send, &
                          dest_buffer%host_buffer, &
                          task_list%rank_sizes_recv, task_list%rank_offsets_recv)

      DEALLOCATE (send_buf)
      CALL timestop(timer_handle)

   END SUBROUTINE rs_scatter_matrices

! **************************************************************************************************
!> \brief Gather the dbcsr matrix blocks and receives them into a buffer as needed after integration.
!>        This is the reverse of rs_scatter_matrices: the data in src_buffer%host_buffer travels
!>        back along the scatter pattern and the received blocks are added to dest_matrices.
!> \param src_buffer blocks produced on this rank, laid out per the task list's recv offsets
!> \param dest_matrices matrices (one per image) the received blocks are summed into
!> \param task_list provides the pair lists, offsets and sizes of both directions
!> \param group communicator over which the all-to-all is performed
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE rs_gather_matrices(src_buffer, dest_matrices, task_list, group)
      TYPE(offload_buffer_type), INTENT(IN)              :: src_buffer
      TYPE(dbcsr_p_type), DIMENSION(:), INTENT(INOUT)    :: dest_matrices
      TYPE(task_list_type), INTENT(IN)                   :: task_list
      TYPE(mp_comm_type), INTENT(IN)                     :: group

      CHARACTER(LEN=*), PARAMETER :: routineN = 'rs_gather_matrices'

      INTEGER                                            :: timer_handle
      REAL(KIND=dp), DIMENSION(:), ALLOCATABLE           :: recv_buf

      CALL timeset(routineN, timer_handle)

      ! The task list's send/recv fields describe the scatter direction, so here they are used
      ! swapped: the buffer below is sized like the scatter's send side but is received into.
      ALLOCATE (recv_buf(task_list%buffer_size_send))

      CALL group%alltoall(src_buffer%host_buffer, task_list%rank_sizes_recv, task_list%rank_offsets_recv, &
                          recv_buf, task_list%rank_sizes_send, task_list%rank_offsets_send)

      ! add the received blocks to the matrices
      CPASSERT(ASSOCIATED(task_list%atom_pair_send))
      CALL rs_unpack_buffer(recv_buf, dest_matrices, &
                            task_list%atom_pair_send, task_list%pair_offsets_send)

      DEALLOCATE (recv_buf)
      CALL timestop(timer_handle)

   END SUBROUTINE rs_gather_matrices

! **************************************************************************************************
!> \brief Copies the DBCSR blocks into the buffer without any communication;
!>        replaces rs_scatter_matrices for non-distributed grids.
!> \param src_matrices matrices (one per image) the blocks are read from
!> \param dest_buffer buffer whose host_buffer is filled
!> \param task_list provides atom_pair_recv and pair_offsets_recv
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE rs_copy_to_buffer(src_matrices, dest_buffer, task_list)
      TYPE(dbcsr_p_type), DIMENSION(:), INTENT(IN)       :: src_matrices
      TYPE(offload_buffer_type), INTENT(INOUT)           :: dest_buffer
      TYPE(task_list_type), INTENT(IN)                   :: task_list

      ! No all-to-all takes place, so the blocks are written straight into the
      ! receive-side layout of the task list.
      CALL rs_pack_buffer(src_matrices, dest_buffer%host_buffer, &
                          task_list%atom_pair_recv, task_list%pair_offsets_recv)

   END SUBROUTINE rs_copy_to_buffer

! **************************************************************************************************
!> \brief Adds the blocks of a buffer into DBCSR matrices without any communication;
!>        replaces rs_gather_matrices for non-distributed grids.
!> \param src_buffer buffer whose host_buffer holds the blocks
!> \param dest_matrices matrices (one per image) the blocks are summed into
!> \param task_list provides atom_pair_recv and pair_offsets_recv
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE rs_copy_to_matrices(src_buffer, dest_matrices, task_list)
      TYPE(offload_buffer_type), INTENT(IN)              :: src_buffer
      TYPE(dbcsr_p_type), DIMENSION(:), INTENT(INOUT)    :: dest_matrices
      TYPE(task_list_type), INTENT(IN)                   :: task_list

      ! No all-to-all takes place, so the buffer is read directly in the
      ! receive-side layout of the task list.
      CALL rs_unpack_buffer(src_buffer%host_buffer, dest_matrices, &
                            task_list%atom_pair_recv, task_list%pair_offsets_recv)

   END SUBROUTINE rs_copy_to_matrices

! **************************************************************************************************
!> \brief Helper routine for rs_scatter_matrices and rs_copy_to_buffer: copies the listed
!>        DBCSR blocks into a flat buffer.
!> \param src_matrices matrices (one per image); every listed block must exist (asserted)
!> \param dest_buffer flat buffer; block i is stored column-major starting at pair_offsets(i)+1
!> \param atom_pair blocks to copy, identified by row, col and image
!> \param pair_offsets 0-based buffer offset of each block
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE rs_pack_buffer(src_matrices, dest_buffer, atom_pair, pair_offsets)
      TYPE(dbcsr_p_type), DIMENSION(:), INTENT(IN)       :: src_matrices
      REAL(KIND=dp), DIMENSION(:), INTENT(INOUT)         :: dest_buffer
      TYPE(atom_pair_type), DIMENSION(:), INTENT(IN)     :: atom_pair
      INTEGER, DIMENSION(:), INTENT(IN)                  :: pair_offsets

      INTEGER                                            :: icol, ipair, nrows, pos
      LOGICAL                                            :: found
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: blk

!$OMP PARALLEL DO DEFAULT(NONE) SCHEDULE(guided), &
!$OMP          SHARED(src_matrices,atom_pair,pair_offsets,dest_buffer), &
!$OMP          PRIVATE(icol,ipair,nrows,pos,found,blk)
      DO ipair = 1, SIZE(atom_pair)
         CALL dbcsr_get_block_p(matrix=src_matrices(atom_pair(ipair)%image)%matrix, &
                                row=atom_pair(ipair)%row, col=atom_pair(ipair)%col, &
                                block=blk, found=found)
         CPASSERT(found)
         ! copying column by column reproduces the column-major flattening of the block
         nrows = SIZE(blk, 1)
         pos = pair_offsets(ipair)
         DO icol = 1, SIZE(blk, 2)
            dest_buffer(pos + 1:pos + nrows) = blk(:, icol)
            pos = pos + nrows
         END DO
      END DO
!$OMP END PARALLEL DO

   END SUBROUTINE rs_pack_buffer

! **************************************************************************************************
!> \brief Helper routine for rs_gather_matrices and rs_copy_to_matrices: adds the blocks stored
!>        in a flat buffer onto existing DBCSR blocks.
!> \param src_buffer flat buffer; block i is stored column-major starting at pair_offsets(i)+1
!> \param dest_matrices matrices (one per image); every listed block must exist (asserted)
!> \param atom_pair target blocks, identified by row, col and image
!> \param pair_offsets 0-based buffer offset of each block
!> \author Ole Schuett
! **************************************************************************************************
   SUBROUTINE rs_unpack_buffer(src_buffer, dest_matrices, atom_pair, pair_offsets)
      REAL(KIND=dp), DIMENSION(:), INTENT(IN)            :: src_buffer
      TYPE(dbcsr_p_type), DIMENSION(:), INTENT(INOUT)    :: dest_matrices
      TYPE(atom_pair_type), DIMENSION(:), INTENT(IN)     :: atom_pair
      INTEGER, DIMENSION(:), INTENT(IN)                  :: pair_offsets

      INTEGER                                            :: arow, icol, ilock, ipair, irow, nlocks, &
                                                            nrows, pos
      LOGICAL                                            :: found
      REAL(KIND=dp), DIMENSION(:, :), POINTER            :: blk
      INTEGER(kind=omp_lock_kind), ALLOCATABLE, DIMENSION(:) :: locks

      ! The same block may be listed more than once, so accumulation is serialised by a
      ! pool of locks selected by block row.
      nlocks = 10*omp_get_max_threads()
      ALLOCATE (locks(nlocks))
      DO ilock = 1, nlocks
         CALL omp_init_lock(locks(ilock))
      END DO

!$OMP PARALLEL DO DEFAULT(NONE) SCHEDULE(guided), &
!$OMP          SHARED(src_buffer,atom_pair,pair_offsets,dest_matrices,locks,nlocks), &
!$OMP          PRIVATE(arow,icol,ilock,ipair,irow,nrows,pos,found,blk)
      DO ipair = 1, SIZE(atom_pair)
         arow = atom_pair(ipair)%row
         CALL dbcsr_get_block_p(matrix=dest_matrices(atom_pair(ipair)%image)%matrix, row=arow, &
                                col=atom_pair(ipair)%col, block=blk, found=found)
         CPASSERT(found)
         nrows = SIZE(blk, 1)
         pos = pair_offsets(ipair)
         ilock = MODULO(arow, nlocks) + 1 ! block rows are assigned to locks round-robin

         CALL omp_set_lock(locks(ilock))
         DO icol = 1, SIZE(blk, 2)
            DO irow = 1, nrows
               blk(irow, icol) = blk(irow, icol) + src_buffer(pos + irow)
            END DO
            pos = pos + nrows
         END DO
         CALL omp_unset_lock(locks(ilock))
      END DO
!$OMP END PARALLEL DO

      DO ilock = 1, nlocks
         CALL omp_destroy_lock(locks(ilock))
      END DO
      DEALLOCATE (locks)

   END SUBROUTINE rs_unpack_buffer

! **************************************************************************************************
!> \brief determines the rank of the processor whose real rs grid contains point
!>        and fills in the corresponding task(s).
!>        If the cube (bounds relative to cube_center) fits into the domain of the rank owning
!>        the cube center, including that domain's border, a single task with dist_type 1 is
!>        written. Otherwise one task with dist_type 2 is written for every domain in the range
!>        of process coordinates needed to cover the cube.
!> \param rs_desc distributed realspace grid descriptor
!> \param igrid_level grid level, encoded together with the rank into the task destination
!> \param n_levels total number of grid levels
!> \param cube_center grid coordinates of the cube center
!> \param ntasks on entry the index of the first task to fill; on exit the index of the
!>               last task written
!> \param tasks task array; grown with reallocate_tasks when more entries are needed
!> \param lb_cube lower bounds of the cube relative to cube_center
!> \param ub_cube upper bounds of the cube relative to cube_center
!> \param added_tasks number of tasks written by this call
!> \par History
!>      11.2007 created [MattW]
!>      10.2008 rewritten [Joost VandeVondele]
!> \author MattW
! **************************************************************************************************
   SUBROUTINE rs_find_node(rs_desc, igrid_level, n_levels, cube_center, ntasks, tasks, &
                           lb_cube, ub_cube, added_tasks)

      TYPE(realspace_grid_desc_type), POINTER            :: rs_desc
      INTEGER, INTENT(IN)                                :: igrid_level, n_levels
      INTEGER, DIMENSION(3), INTENT(IN)                  :: cube_center
      INTEGER, INTENT(INOUT)                             :: ntasks
      TYPE(task_type), DIMENSION(:), POINTER             :: tasks
      INTEGER, DIMENSION(3), INTENT(IN)                  :: lb_cube, ub_cube
      INTEGER, INTENT(OUT)                               :: added_tasks

      ! growth policy of the task array: new size = (old size + add_tasks)*mult_tasks
      INTEGER, PARAMETER                                 :: add_tasks = 1000
      REAL(kind=dp), PARAMETER                           :: mult_tasks = 2.0_dp

      INTEGER :: bit_index, coord(3), curr_tasks, dest, i, icoord(3), idest, itask, ix, iy, iz, &
                 lb_coord(3), lb_domain(3), lbc(3), ub_coord(3), ub_domain(3), ubc(3)
      INTEGER                                            :: bit_pattern
      LOGICAL                                            :: dir_periodic(3)

      ! rank owning the cube center
      coord(1) = rs_desc%x2coord(cube_center(1))
      coord(2) = rs_desc%y2coord(cube_center(2))
      coord(3) = rs_desc%z2coord(cube_center(3))
      dest = rs_desc%coord2rank(coord(1), coord(2), coord(3))

      ! the real cube coordinates
      lbc = lb_cube + cube_center
      ubc = ub_cube + cube_center

      IF (ALL((rs_desc%lb_global(:, dest) - rs_desc%border) <= lbc) .AND. &
          ALL((rs_desc%ub_global(:, dest) + rs_desc%border) >= ubc)) THEN
         !standard distributed collocation/integration
         tasks(ntasks)%destination = encode_rank(dest, igrid_level, n_levels)
         tasks(ntasks)%dist_type = 1
         tasks(ntasks)%subpatch_pattern = 0
         added_tasks = 1

         ! here we figure out if there is an alternate location for this task
         ! i.e. even though the cube_center is not in the real local domain,
         ! it might fully fit in the halo of the neighbor
         ! if its radius is smaller than the maximum radius
         ! the list of possible neighbors is stored here as a bit pattern
         ! to reconstruct what the rank of the neighbor is,
         ! we can use (note this requires the correct rs_grid)
         !  IF (BTEST(bit_pattern,0)) rank=rs_grid_locate_rank(rs_desc,tasks(ntasks)%destination,[-1,0,0])
         !  IF (BTEST(bit_pattern,1)) rank=rs_grid_locate_rank(rs_desc,tasks(ntasks)%destination,[+1,0,0])
         !  IF (BTEST(bit_pattern,2)) rank=rs_grid_locate_rank(rs_desc,tasks(ntasks)%destination,[0,-1,0])
         !  IF (BTEST(bit_pattern,3)) rank=rs_grid_locate_rank(rs_desc,tasks(ntasks)%destination,[0,+1,0])
         !  IF (BTEST(bit_pattern,4)) rank=rs_grid_locate_rank(rs_desc,tasks(ntasks)%destination,[0,0,-1])
         !  IF (BTEST(bit_pattern,5)) rank=rs_grid_locate_rank(rs_desc,tasks(ntasks)%destination,[0,0,+1])
         bit_index = 0
         bit_pattern = 0
         DO i = 1, 3
            IF (rs_desc%perd(i) == 1) THEN
               bit_pattern = IBCLR(bit_pattern, bit_index)
               bit_index = bit_index + 1
               bit_pattern = IBCLR(bit_pattern, bit_index)
               bit_index = bit_index + 1
            ELSE
               ! fits the left neighbor ?
               IF (ubc(i) <= rs_desc%lb_global(i, dest) - 1 + rs_desc%border) THEN
                  bit_pattern = IBSET(bit_pattern, bit_index)
                  bit_index = bit_index + 1
               ELSE
                  bit_pattern = IBCLR(bit_pattern, bit_index)
                  bit_index = bit_index + 1
               END IF
               ! fits the right neighbor ?
               IF (lbc(i) >= rs_desc%ub_global(i, dest) + 1 - rs_desc%border) THEN
                  bit_pattern = IBSET(bit_pattern, bit_index)
                  bit_index = bit_index + 1
               ELSE
                  bit_pattern = IBCLR(bit_pattern, bit_index)
                  bit_index = bit_index + 1
               END IF
            END IF
         END DO
         tasks(ntasks)%subpatch_pattern = bit_pattern

      ELSE
         ! generalised collocation/integration needed
         ! first we figure out how many neighbors we have to add to include the lbc/ubc
         ! in the available domains (inclusive of halo points)
         ! first we 'ignore' periodic boundary conditions
         ! i.e. ub_coord-lb_coord+1 might be larger than group_dim
         lb_coord = coord
         ub_coord = coord
         lb_domain = rs_desc%lb_global(:, dest) - rs_desc%border
         ub_domain = rs_desc%ub_global(:, dest) + rs_desc%border
         DO i = 1, 3
            ! only if the grid is not periodic in this direction we need to take care of adding neighbors
            IF (rs_desc%perd(i) == 0) THEN
               ! if the domain lower bound is greater than the lbc we need to add the size of the neighbor domain
               DO
                  IF (lb_domain(i) > lbc(i)) THEN
                     lb_coord(i) = lb_coord(i) - 1
                     icoord = MODULO(lb_coord, rs_desc%group_dim)
                     idest = rs_desc%coord2rank(icoord(1), icoord(2), icoord(3))
                     lb_domain(i) = lb_domain(i) - (rs_desc%ub_global(i, idest) - rs_desc%lb_global(i, idest) + 1)
                  ELSE
                     EXIT
                  END IF
               END DO
               ! same for the upper bound
               DO
                  IF (ub_domain(i) < ubc(i)) THEN
                     ub_coord(i) = ub_coord(i) + 1
                     icoord = MODULO(ub_coord, rs_desc%group_dim)
                     idest = rs_desc%coord2rank(icoord(1), icoord(2), icoord(3))
                     ub_domain(i) = ub_domain(i) + (rs_desc%ub_global(i, idest) - rs_desc%lb_global(i, idest) + 1)
                  ELSE
                     EXIT
                  END IF
               END DO
            END IF
         END DO

         ! some care is needed for the periodic boundaries
         DO i = 1, 3
            IF (ub_domain(i) - lb_domain(i) + 1 >= rs_desc%npts(i)) THEN
               dir_periodic(i) = .TRUE.
               lb_coord(i) = 0
               ub_coord(i) = rs_desc%group_dim(i) - 1
            ELSE
               dir_periodic(i) = .FALSE.
            END IF
         END DO

         added_tasks = PRODUCT(ub_coord - lb_coord + 1)
         itask = ntasks
         ntasks = ntasks + added_tasks - 1
         IF (ntasks > SIZE(tasks)) THEN
            curr_tasks = INT((SIZE(tasks) + add_tasks)*mult_tasks)
            ! a single cube can span more domains than the geometric growth provides;
            ! never reallocate to fewer entries than are about to be written below
            curr_tasks = MAX(curr_tasks, ntasks)
            CALL reallocate_tasks(tasks, curr_tasks)
         END IF
         DO iz = lb_coord(3), ub_coord(3)
         DO iy = lb_coord(2), ub_coord(2)
         DO ix = lb_coord(1), ub_coord(1)
            icoord = MODULO([ix, iy, iz], rs_desc%group_dim)
            idest = rs_desc%coord2rank(icoord(1), icoord(2), icoord(3))
            tasks(itask)%destination = encode_rank(idest, igrid_level, n_levels)
            tasks(itask)%dist_type = 2
            tasks(itask)%subpatch_pattern = 0
            ! encode the domain size for this task
            ! if the bit is set, we need to add the border in that direction
            IF (ix == lb_coord(1) .AND. .NOT. dir_periodic(1)) &
               tasks(itask)%subpatch_pattern = IBSET(tasks(itask)%subpatch_pattern, 0)
            IF (ix == ub_coord(1) .AND. .NOT. dir_periodic(1)) &
               tasks(itask)%subpatch_pattern = IBSET(tasks(itask)%subpatch_pattern, 1)
            IF (iy == lb_coord(2) .AND. .NOT. dir_periodic(2)) &
               tasks(itask)%subpatch_pattern = IBSET(tasks(itask)%subpatch_pattern, 2)
            IF (iy == ub_coord(2) .AND. .NOT. dir_periodic(2)) &
               tasks(itask)%subpatch_pattern = IBSET(tasks(itask)%subpatch_pattern, 3)
            IF (iz == lb_coord(3) .AND. .NOT. dir_periodic(3)) &
               tasks(itask)%subpatch_pattern = IBSET(tasks(itask)%subpatch_pattern, 4)
            IF (iz == ub_coord(3) .AND. .NOT. dir_periodic(3)) &
               tasks(itask)%subpatch_pattern = IBSET(tasks(itask)%subpatch_pattern, 5)
            itask = itask + 1
         END DO
         END DO
         END DO
      END IF

   END SUBROUTINE rs_find_node

! **************************************************************************************************
!> \brief utility functions for encoding the grid level with a rank, allowing
!>        different grid levels to maintain a different rank ordering without
!>        losing information.  These encoded_ints are stored in the destination
!>        and source fields of the tasks.
!>        The encoding is rank*n_levels + (grid_level - 1), so sorting by the
!>        encoded value still sorts by rank first.
!> \param rank 0-based process rank
!> \param grid_level 1-based grid level, expected in 1..n_levels
!> \param n_levels number of grid levels
!> \return the combined integer (inverted by decode_rank and decode_level)
!> \par History
!>      4.2009 created [Iain Bethune]
!>        (c) The Numerical Algorithms Group (NAG) Ltd, 2009 on behalf of the HECToR project
! **************************************************************************************************
   PURE FUNCTION encode_rank(rank, grid_level, n_levels) RESULT(encoded_int)

      INTEGER, INTENT(IN)                                :: rank, grid_level, n_levels
      INTEGER                                            :: encoded_int

      ! rank occupies the high "digit", so the ordering by rank is preserved
      encoded_int = rank*n_levels + grid_level - 1

   END FUNCTION encode_rank

! **************************************************************************************************
!> \brief Extract the rank from an integer produced by encode_rank.
!>        Uses floored division so that, together with decode_level (which uses MODULO), it
!>        exactly inverts encode_rank, also for negative encoded values. For non-negative
!>        encoded_int the result is identical to plain integer division.
!> \param encoded_int value produced by encode_rank(rank, grid_level, n_levels)
!> \param n_levels number of grid levels used for the encoding (assumed > 0)
!> \return the rank that was encoded
! **************************************************************************************************
   PURE FUNCTION decode_rank(encoded_int, n_levels) RESULT(rank)

      INTEGER, INTENT(IN)                                :: encoded_int
      INTEGER, INTENT(IN)                                :: n_levels
      INTEGER                                            :: rank

      ! Fortran integer division truncates toward zero. Subtracting the MODULO remainder
      ! (in [0, n_levels-1] for n_levels > 0) first makes the division exact, i.e. floored,
      ! which matches the MODULO-based split used by decode_level.
      rank = (encoded_int - MODULO(encoded_int, n_levels))/n_levels

   END FUNCTION decode_rank

! **************************************************************************************************
!> \brief Extract the one-based grid level from an integer produced by encode_rank.
!> \param encoded_int value produced by encode_rank(rank, grid_level, n_levels)
!> \param n_levels number of grid levels used for the encoding
!> \return 1 + MODULO(encoded_int, n_levels)
! **************************************************************************************************
   FUNCTION decode_level(encoded_int, n_levels) RESULT(grid_level)

      INTEGER, INTENT(IN)                                :: encoded_int, n_levels
      INTEGER                                            :: grid_level

      ! MODULO (not MOD) keeps the remainder non-negative for a positive n_levels
      grid_level = 1 + MODULO(encoded_int, n_levels)

   END FUNCTION decode_level

! **************************************************************************************************
!> \brief Lexicographic "less than" comparison of two tasks, used when sorting the task list.
!>        Keys, from most to least significant:
!>        grid_level, image, iatom, jatom, iset, jset, ipgf, jpgf.
!>        After the sort in distribute_tasks this yields the right order for collocation:
!>        within a grid level and image, all tasks of an atom pair are adjacent, within an
!>        atom pair all set pairs are grouped, and within a set pair all pgfs are grouped,
!>        so e.g. all sets and pgfs of an atom pair are computed in one go.
!>        Ordering by grid level first means that a given density matrix block is
!>        decontracted several times, but cache effects on the grid make up for this.
!> \param a first task
!> \param b second task
!> \return .TRUE. if a sorts strictly before b, .FALSE. otherwise (including equal keys)
!> \author Ole Schuett
! **************************************************************************************************
   PURE FUNCTION tasks_less_than(a, b) RESULT(res)
      TYPE(task_type), INTENT(IN)                        :: a, b
      LOGICAL                                            :: res

      ! The first key that differs decides the ordering.
      IF (a%grid_level /= b%grid_level) THEN
         res = (a%grid_level < b%grid_level)
         RETURN
      END IF
      IF (a%image /= b%image) THEN
         res = (a%image < b%image)
         RETURN
      END IF
      IF (a%iatom /= b%iatom) THEN
         res = (a%iatom < b%iatom)
         RETURN
      END IF
      IF (a%jatom /= b%jatom) THEN
         res = (a%jatom < b%jatom)
         RETURN
      END IF
      IF (a%iset /= b%iset) THEN
         res = (a%iset < b%iset)
         RETURN
      END IF
      IF (a%jset /= b%jset) THEN
         res = (a%jset < b%jset)
         RETURN
      END IF
      IF (a%ipgf /= b%ipgf) THEN
         res = (a%ipgf < b%ipgf)
         RETURN
      END IF

      ! All more significant keys tie: the last key alone decides.
      res = (a%jpgf < b%jpgf)
   END FUNCTION tasks_less_than

   ! Expand the fypp array_sort template for TYPE(task_type) with prefix 'tasks'.
   ! NOTE(review): presumably the generated sort routine orders by tasks_less_than
   ! (matching prefix) — confirm against the array_sort template definition.
   #:call array_sort(prefix='tasks', type='TYPE(task_type)')
   #:endcall

! **************************************************************************************************
!> \brief Lexicographic "less than" comparison of two atom pairs, so that sorting places
!>        duplicate atom pairs next to each other.
!>        Keys, from most to least significant: rank, row, col, image.
!> \param a first atom pair
!> \param b second atom pair
!> \return .TRUE. if a sorts strictly before b, .FALSE. otherwise (including equal keys)
!> \author Ole Schuett
! **************************************************************************************************
   PURE FUNCTION atom_pair_less_than(a, b) RESULT(res)
      TYPE(atom_pair_type), INTENT(IN)                   :: a, b
      LOGICAL                                            :: res

      ! The first key that differs decides the ordering.
      IF (a%rank /= b%rank) THEN
         res = (a%rank < b%rank)
         RETURN
      END IF
      IF (a%row /= b%row) THEN
         res = (a%row < b%row)
         RETURN
      END IF
      IF (a%col /= b%col) THEN
         res = (a%col < b%col)
         RETURN
      END IF

      ! rank, row and col tie: image alone decides.
      res = (a%image < b%image)
   END FUNCTION atom_pair_less_than

   ! Expand the fypp array_sort template for TYPE(atom_pair_type) with prefix 'atom_pair'.
   ! NOTE(review): presumably the generated sort routine orders by atom_pair_less_than
   ! (matching prefix) — confirm against the array_sort template definition.
   #:call array_sort(prefix='atom_pair', type='TYPE(atom_pair_type)')
   #:endcall

END MODULE task_list_methods
